[jvm-packages] XGBoost Spark integration refactor (#3387)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * [jvm-packages] XGBoost Spark integration refactor. (#3313) * XGBoost Spark integration refactor. * Make corresponding update for xgboost4j-example * Address comments. * [jvm-packages] Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib (#3326) * Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib * Fix extra space. * [jvm-packages] XGBoost Spark supports ranking with group data. (#3369) * XGBoost Spark supports ranking with group data. * Use Iterator.duplicate to prevent OOM. * Update CheckpointManagerSuite.scala * Resolve conflicts
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.SparkContext
|
||||
@@ -63,9 +64,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
val version = versions.max
|
||||
val fullPath = getPath(version)
|
||||
logger.info(s"Start training from previous booster at $fullPath")
|
||||
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
|
||||
model.booster.booster.setVersion(version)
|
||||
model.booster
|
||||
val booster = SXGBoost.loadModel(fullPath)
|
||||
booster.booster.setVersion(version)
|
||||
booster
|
||||
} else {
|
||||
null
|
||||
}
|
||||
@@ -76,12 +77,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
|
||||
*
|
||||
* @param checkpoint the checkpoint to save as an XGBoostModel
|
||||
*/
|
||||
private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
|
||||
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
|
||||
val fs = FileSystem.get(sc.hadoopConfiguration)
|
||||
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
|
||||
val fullPath = getPath(checkpoint.version)
|
||||
logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
|
||||
checkpoint.saveModelAsHadoopFile(fullPath)(sc)
|
||||
val fullPath = getPath(checkpoint.getVersion)
|
||||
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
|
||||
checkpoint.saveModel(fullPath)
|
||||
prevModelPaths.foreach(path => fs.delete(path, true))
|
||||
}
|
||||
|
||||
|
||||
@@ -21,16 +21,15 @@ import java.nio.file.Files
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.{FSDataInputStream, Path}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
|
||||
|
||||
@@ -134,7 +133,7 @@ object XGBoost extends Serializable {
|
||||
fromBaseMarginsToArray(baseMargins), cacheDirName)
|
||||
|
||||
try {
|
||||
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
|
||||
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
||||
val booster = SXGBoost.train(watches.train, params, round,
|
||||
@@ -148,89 +147,6 @@ object XGBoost extends Serializable {
|
||||
}.cache()
|
||||
}
|
||||
|
||||
/**
|
||||
* Train XGBoost model with the DataFrame-represented data
|
||||
*
|
||||
* @param trainingData the training set represented as DataFrame
|
||||
* @param params Map containing the parameters to configure XGBoost
|
||||
* @param round the number of iterations
|
||||
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
|
||||
* workers equals to the partition number of trainingData RDD
|
||||
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
|
||||
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing The value which represents a missing value in the dataset
|
||||
* @param featureCol the name of input column, "features" as default value
|
||||
* @param labelCol the name of output column, "label" as default value
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def trainWithDataFrame(
|
||||
trainingData: Dataset[_],
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nWorkers: Int,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN,
|
||||
featureCol: String = "features",
|
||||
labelCol: String = "label"): XGBoostModel = {
|
||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
||||
val estimator = new XGBoostEstimator(params)
|
||||
// assigning general parameters
|
||||
estimator.
|
||||
set(estimator.useExternalMemory, useExternalMemory).
|
||||
set(estimator.round, round).
|
||||
set(estimator.nWorkers, nWorkers).
|
||||
set(estimator.customObj, obj).
|
||||
set(estimator.customEval, eval).
|
||||
set(estimator.missing, missing).
|
||||
setFeaturesCol(featureCol).
|
||||
setLabelCol(labelCol).
|
||||
fit(trainingData)
|
||||
}
|
||||
|
||||
private[spark] def isClassificationTask(params: Map[String, Any]): Boolean = {
|
||||
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
|
||||
objective != null && {
|
||||
val objStr = objective.toString
|
||||
objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" &&
|
||||
!objStr.startsWith("rank:")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Train XGBoost model with the RDD-represented data
|
||||
*
|
||||
* @param trainingData the training set represented as RDD
|
||||
* @param params Map containing the configuration entries
|
||||
* @param round the number of iterations
|
||||
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
|
||||
* workers equals to the partition number of trainingData RDD
|
||||
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
|
||||
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing the value represented the missing value in the dataset
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
|
||||
* @return XGBoostModel when successful training
|
||||
*/
|
||||
@deprecated("Use XGBoost.trainWithRDD instead.")
|
||||
def train(
|
||||
trainingData: RDD[MLLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nWorkers: Int,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN): XGBoostModel = {
|
||||
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
|
||||
}
|
||||
|
||||
private def overrideParamsAccordingToTaskCPUs(
|
||||
params: Map[String, Any],
|
||||
sc: SparkContext): Map[String, Any] = {
|
||||
@@ -259,39 +175,8 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* Train XGBoost model with the RDD-represented data
|
||||
*
|
||||
* @param trainingData the training set represented as RDD
|
||||
* @param params Map containing the configuration entries
|
||||
* @param round the number of iterations
|
||||
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
|
||||
* workers equals to the partition number of trainingData RDD
|
||||
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
|
||||
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
|
||||
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
|
||||
* true, the user may save the RAM cost for running XGBoost within Spark
|
||||
* @param missing The value which represents a missing value in the dataset
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
|
||||
* @return XGBoostModel when successful training
|
||||
* @return A tuple of the booster and the metrics used to build training summary
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def trainWithRDD(
|
||||
trainingData: RDD[MLLabeledPoint],
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nWorkers: Int,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN): XGBoostModel = {
|
||||
import DataUtils._
|
||||
val xgbTrainingData = trainingData.map { case MLLabeledPoint(label, features) =>
|
||||
features.asXGB.copy(label = label.toFloat)
|
||||
}
|
||||
trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval,
|
||||
useExternalMemory, missing)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
private[spark] def trainDistributed(
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
@@ -301,7 +186,7 @@ object XGBoost extends Serializable {
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
useExternalMemory: Boolean = false,
|
||||
missing: Float = Float.NaN): XGBoostModel = {
|
||||
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
|
||||
if (params.contains("tree_method")) {
|
||||
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
|
||||
" for now")
|
||||
@@ -350,20 +235,15 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val isClsTask = isClassificationTask(params)
|
||||
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
|
||||
sparkJobThread, isClsTask)
|
||||
if (isClsTask){
|
||||
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
|
||||
params.getOrElse("num_class", "2").toString.toInt
|
||||
}
|
||||
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
|
||||
sparkJobThread)
|
||||
if (checkpointRound < round) {
|
||||
prevBooster = model.booster
|
||||
checkpointManager.updateCheckpoint(model)
|
||||
prevBooster = booster
|
||||
checkpointManager.updateCheckpoint(prevBooster)
|
||||
}
|
||||
model
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
@@ -383,17 +263,14 @@ object XGBoost extends Serializable {
|
||||
private def postTrackerReturnProcessing(
|
||||
trackerReturnVal: Int,
|
||||
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
|
||||
sparkJobThread: Thread,
|
||||
isClassificationTask: Boolean
|
||||
): XGBoostModel = {
|
||||
sparkJobThread: Thread): (Booster, Map[String, Array[Float]]) = {
|
||||
if (trackerReturnVal == 0) {
|
||||
// Copies of the final booster and the corresponding metrics
|
||||
// reside in each partition of the `distributedBoostersAndMetrics`.
|
||||
// Any of them can be used to create the model.
|
||||
val (booster, metrics) = distributedBoostersAndMetrics.first()
|
||||
val xgboostModel = XGBoostModel(booster, isClassificationTask)
|
||||
distributedBoostersAndMetrics.unpersist(false)
|
||||
xgboostModel.setSummary(XGBoostTrainingSummary(metrics))
|
||||
(booster, metrics)
|
||||
} else {
|
||||
try {
|
||||
if (sparkJobThread.isAlive) {
|
||||
@@ -407,64 +284,6 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = {
|
||||
val featureCol = inputStream.readUTF()
|
||||
val labelCol = inputStream.readUTF()
|
||||
val predictionCol = inputStream.readUTF()
|
||||
(featureCol, labelCol, predictionCol)
|
||||
}
|
||||
|
||||
private def setGeneralModelParams(
|
||||
featureCol: String,
|
||||
labelCol: String,
|
||||
predCol: String,
|
||||
xgBoostModel: XGBoostModel): XGBoostModel = {
|
||||
xgBoostModel.setFeaturesCol(featureCol)
|
||||
xgBoostModel.setLabelCol(labelCol)
|
||||
xgBoostModel.setPredictionCol(predCol)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path in HDFS-compatible file system
|
||||
*
|
||||
* @param modelPath The path of the file representing the model
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoopFile(modelPath: String)(implicit sparkContext: SparkContext):
|
||||
XGBoostModel = {
|
||||
val path = new Path(modelPath)
|
||||
val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path)
|
||||
val modelType = dataInStream.readUTF()
|
||||
val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream)
|
||||
modelType match {
|
||||
case "_cls_" =>
|
||||
val rawPredictionCol = dataInStream.readUTF()
|
||||
val numClasses = dataInStream.readInt()
|
||||
val thresholdLength = dataInStream.readInt()
|
||||
var thresholds: Array[Double] = null
|
||||
if (thresholdLength != -1) {
|
||||
thresholds = new Array[Double](thresholdLength)
|
||||
for (i <- 0 until thresholdLength) {
|
||||
thresholds(i) = dataInStream.readDouble()
|
||||
}
|
||||
}
|
||||
val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream))
|
||||
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel).
|
||||
asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol)
|
||||
if (thresholdLength != -1) {
|
||||
xgBoostModel.setThresholds(thresholds)
|
||||
}
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses
|
||||
xgBoostModel
|
||||
case "_reg_" =>
|
||||
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
|
||||
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
|
||||
case other =>
|
||||
throw new XGBoostError(s"Unknown model type $other. Supported types " +
|
||||
s"are: ['_reg_', '_cls_'].")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class Watches private(
|
||||
@@ -489,12 +308,29 @@ private class Watches private(
|
||||
|
||||
private object Watches {
|
||||
|
||||
def buildGroups(groups: Seq[Int]): Seq[Int] = {
|
||||
val output = mutable.ArrayBuffer.empty[Int]
|
||||
var count = 1
|
||||
var i = 1
|
||||
while (i < groups.length) {
|
||||
if (groups(i) != groups(i - 1)) {
|
||||
output += count
|
||||
count = 1
|
||||
} else {
|
||||
count += 1
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
output += count
|
||||
output
|
||||
}
|
||||
|
||||
def apply(
|
||||
params: Map[String, Any],
|
||||
labeledPoints: Iterator[XGBLabeledPoint],
|
||||
baseMarginsOpt: Option[Array[Float]],
|
||||
cacheDirName: Option[String]): Watches = {
|
||||
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
|
||||
val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
|
||||
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
|
||||
val r = new Random(seed)
|
||||
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
|
||||
@@ -506,8 +342,18 @@ private object Watches {
|
||||
|
||||
accepted
|
||||
}
|
||||
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
|
||||
|
||||
val (trainIter1, trainIter2) = trainPoints.duplicate
|
||||
val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
|
||||
val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
|
||||
trainMatrix.setGroup(trainGroups)
|
||||
|
||||
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
|
||||
if (trainTestRatio < 1.0) {
|
||||
val testGroups = buildGroups(testPoints.map(_.group)).toArray
|
||||
testMatrix.setGroup(testGroups)
|
||||
}
|
||||
|
||||
r.setSeed(seed)
|
||||
for (baseMargins <- baseMarginsOpt) {
|
||||
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
|
||||
@@ -515,11 +361,6 @@ private object Watches {
|
||||
testMatrix.setBaseMargin(testMargin)
|
||||
}
|
||||
|
||||
// TODO: use group attribute from the points.
|
||||
if (params.contains("groupData") && params("groupData") != null) {
|
||||
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
|
||||
TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
new Watches(trainMatrix, testMatrix, cacheDirName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
|
||||
import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap}
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
||||
/**
|
||||
* class of the XGBoost model used for classification task
|
||||
*/
|
||||
class XGBoostClassificationModel private[spark](
|
||||
override val uid: String, booster: Booster)
|
||||
extends XGBoostModel(booster) {
|
||||
|
||||
def this(booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), booster)
|
||||
|
||||
// only called in copy()
|
||||
def this(uid: String) = this(uid, null)
|
||||
|
||||
// scalastyle:off
|
||||
|
||||
/**
|
||||
* whether to output raw margin
|
||||
*/
|
||||
final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value")
|
||||
|
||||
setDefault(outputMargin, false)
|
||||
|
||||
def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel]
|
||||
|
||||
/**
|
||||
* the name of the column storing the raw prediction value, either probabilities (as default) or
|
||||
* raw margin value
|
||||
*/
|
||||
final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).")
|
||||
|
||||
setDefault(rawPredictionCol, "probabilities")
|
||||
|
||||
final def getRawPredictionCol: String = $(rawPredictionCol)
|
||||
|
||||
def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel]
|
||||
|
||||
/**
|
||||
* Thresholds in multi-class classification
|
||||
*/
|
||||
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
|
||||
|
||||
def getThresholds: Array[Double] = $(thresholds)
|
||||
|
||||
def setThresholds(value: Array[Double]): XGBoostClassificationModel =
|
||||
set(thresholds, value).asInstanceOf[XGBoostClassificationModel]
|
||||
|
||||
// scalastyle:on
|
||||
|
||||
// generate dataframe containing raw prediction column which is typed as Vector
|
||||
private def predictRaw(
|
||||
testSet: Dataset[_],
|
||||
temporalColName: Option[String] = None,
|
||||
forceTransformedScore: Option[Boolean] = None): DataFrame = {
|
||||
val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin)))
|
||||
val colName = temporalColName.getOrElse($(rawPredictionCol))
|
||||
val tempColName = colName + "_arraytype"
|
||||
val dsWithArrayTypedRawPredCol = testSet.sparkSession.createDataFrame(predictRDD, schema = {
|
||||
testSet.schema.add(tempColName, ArrayType(FloatType, containsNull = false))
|
||||
})
|
||||
val transformerForProbabilitiesArray =
|
||||
(rawPredArray: mutable.WrappedArray[Float]) =>
|
||||
if (numClasses == 2) {
|
||||
Array(1 - rawPredArray(0), rawPredArray(0)).map(_.toDouble)
|
||||
} else {
|
||||
rawPredArray.map(_.toDouble).array
|
||||
}
|
||||
dsWithArrayTypedRawPredCol.withColumn(colName,
|
||||
udf((rawPredArray: mutable.WrappedArray[Float]) =>
|
||||
new MLDenseVector(transformerForProbabilitiesArray(rawPredArray))).apply(col(tempColName))).
|
||||
drop(tempColName)
|
||||
}
|
||||
|
||||
private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = {
|
||||
val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol"))
|
||||
val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol"))
|
||||
val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF)
|
||||
val allColumnNames = testSet.columns ++ Seq($(predictionCol))
|
||||
tempDF.select(allColumnNames(0), allColumnNames.tail: _*)
|
||||
}
|
||||
|
||||
private def argMax(vector: Array[Double]): Double = {
|
||||
vector.zipWithIndex.maxBy(_._1)._2
|
||||
}
|
||||
|
||||
private def raw2prediction(rawPrediction: MLDenseVector): Double = {
|
||||
if (!isDefined(thresholds)) {
|
||||
argMax(rawPrediction.values)
|
||||
} else {
|
||||
probability2prediction(rawPrediction)
|
||||
}
|
||||
}
|
||||
|
||||
private def probability2prediction(probability: MLDenseVector): Double = {
|
||||
if (!isDefined(thresholds)) {
|
||||
argMax(probability.values)
|
||||
} else {
|
||||
val thresholds: Array[Double] = getThresholds
|
||||
val scaledProbability =
|
||||
probability.values.zip(thresholds).map { case (p, t) =>
|
||||
if (t == 0.0) Double.PositiveInfinity else p / t
|
||||
}
|
||||
argMax(scaledProbability)
|
||||
}
|
||||
}
|
||||
|
||||
override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
|
||||
transformSchema(testSet.schema, logging = true)
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".transform() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
if ($(outputMargin)) {
|
||||
setRawPredictionCol("margin")
|
||||
}
|
||||
var outputData = testSet
|
||||
var numColsOutput = 0
|
||||
if ($(rawPredictionCol).nonEmpty) {
|
||||
outputData = predictRaw(testSet)
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
if ($(rawPredictionCol).nonEmpty) {
|
||||
require(!$(outputMargin), "XGBoost does not support output final prediction with" +
|
||||
" untransformed margin. Please set predictionCol as \"\" when setting outputMargin as" +
|
||||
" true")
|
||||
val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol)))
|
||||
outputData = outputData.withColumn($(predictionCol), rawToPredUDF)
|
||||
} else {
|
||||
outputData = fromFeatureToPrediction(testSet)
|
||||
}
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if (numColsOutput == 0) {
|
||||
this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" +
|
||||
" since no output columns were set.")
|
||||
}
|
||||
outputData.toDF()
|
||||
}
|
||||
|
||||
private[spark] var numOfClasses = 2
|
||||
|
||||
def numClasses: Int = numOfClasses
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostClassificationModel = {
|
||||
val newModel = copyValues(new XGBoostClassificationModel(booster), extra)
|
||||
newModel.setSummary(summary)
|
||||
}
|
||||
|
||||
override protected def predict(features: MLVector): Double = {
|
||||
throw new Exception("XGBoost does not support online prediction ")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.ml.classification._
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql._
|
||||
import org.json4s.DefaultFormats
|
||||
|
||||
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
|
||||
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
|
||||
|
||||
class XGBoostClassifier (
|
||||
override val uid: String,
|
||||
private val xgboostParams: Map[String, Any])
|
||||
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
|
||||
with XGBoostClassifierParams with DefaultParamsWritable {
|
||||
|
||||
def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]())
|
||||
|
||||
def this(uid: String) = this(uid, Map[String, Any]())
|
||||
|
||||
def this(xgboostParams: Map[String, Any]) = this(
|
||||
Identifiable.randomUID("xgbc"), xgboostParams)
|
||||
|
||||
XGBoostToMLlibParams(xgboostParams)
|
||||
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
|
||||
|
||||
def setNumClass(value: Int): this.type = set(numClass, value)
|
||||
|
||||
// setters for general params
|
||||
def setNumRound(value: Int): this.type = set(numRound, value)
|
||||
|
||||
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
|
||||
|
||||
def setNthread(value: Int): this.type = set(nthread, value)
|
||||
|
||||
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
|
||||
|
||||
def setSilent(value: Int): this.type = set(silent, value)
|
||||
|
||||
def setMissing(value: Float): this.type = set(missing, value)
|
||||
|
||||
def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value)
|
||||
|
||||
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
|
||||
|
||||
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
|
||||
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
// setters for booster params
|
||||
def setBooster(value: String): this.type = set(booster, value)
|
||||
|
||||
def setEta(value: Double): this.type = set(eta, value)
|
||||
|
||||
def setGamma(value: Double): this.type = set(gamma, value)
|
||||
|
||||
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
|
||||
|
||||
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
|
||||
|
||||
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
|
||||
|
||||
def setSubsample(value: Double): this.type = set(subsample, value)
|
||||
|
||||
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
|
||||
|
||||
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
|
||||
|
||||
def setLambda(value: Double): this.type = set(lambda, value)
|
||||
|
||||
def setAlpha(value: Double): this.type = set(alpha, value)
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
|
||||
def setSketchEps(value: Double): this.type = set(sketchEps, value)
|
||||
|
||||
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
|
||||
|
||||
def setSampleType(value: String): this.type = set(sampleType, value)
|
||||
|
||||
def setNormalizeType(value: String): this.type = set(normalizeType, value)
|
||||
|
||||
def setRateDrop(value: Double): this.type = set(rateDrop, value)
|
||||
|
||||
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
|
||||
|
||||
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
|
||||
|
||||
// setters for learning params
|
||||
def setObjective(value: String): this.type = set(objective, value)
|
||||
|
||||
def setBaseScore(value: Double): this.type = set(baseScore, value)
|
||||
|
||||
def setEvalMetric(value: String): this.type = set(evalMetric, value)
|
||||
|
||||
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
|
||||
|
||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
if ($(objective).startsWith("multi")) {
|
||||
// multi
|
||||
"merror"
|
||||
} else {
|
||||
// binary
|
||||
"error"
|
||||
}
|
||||
}
|
||||
|
||||
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
|
||||
|
||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
||||
set(evalMetric, setupDefaultEvalMetric())
|
||||
}
|
||||
|
||||
val _numClasses = getNumClasses(dataset)
|
||||
if (isDefined(numClass) && $(numClass) != _numClasses) {
|
||||
throw new Exception("The number of classes in dataset doesn't match " +
|
||||
"\'num_class\' in xgboost params.")
|
||||
}
|
||||
|
||||
val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
|
||||
lit(Float.NaN)
|
||||
} else {
|
||||
col($(baseMarginCol))
|
||||
}
|
||||
|
||||
val instances: RDD[XGBLabeledPoint] = dataset.select(
|
||||
col($(featuresCol)),
|
||||
col($(labelCol)).cast(FloatType),
|
||||
baseMargin.cast(FloatType),
|
||||
weight.cast(FloatType)
|
||||
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
}
|
||||
XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
|
||||
}
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
val derivedXGBParamMap = MLlib2XGBoostParams
|
||||
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
|
||||
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing))
|
||||
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
|
||||
val summary = XGBoostTrainingSummary(_metrics)
|
||||
model.setSummary(summary)
|
||||
model
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra)
|
||||
}
|
||||
|
||||
object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
|
||||
|
||||
override def load(path: String): XGBoostClassifier = super.load(path)
|
||||
}
|
||||
|
||||
class XGBoostClassificationModel private[ml](
|
||||
override val uid: String,
|
||||
override val numClasses: Int,
|
||||
private[spark] val _booster: Booster)
|
||||
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
|
||||
with XGBoostClassifierParams with MLWritable with Serializable {
|
||||
|
||||
import XGBoostClassificationModel._
|
||||
|
||||
// only called in copy()
|
||||
def this(uid: String) = this(uid, 2, null)
|
||||
|
||||
private var trainingSummary: Option[XGBoostTrainingSummary] = None
|
||||
|
||||
/**
|
||||
* Returns summary (e.g. train/test objective history) of model on the
|
||||
* training set. An exception is thrown if no summary is available.
|
||||
*/
|
||||
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
|
||||
throw new IllegalStateException("No training summary available for this XGBoostModel")
|
||||
}
|
||||
|
||||
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
|
||||
trainingSummary = Some(summary)
|
||||
this
|
||||
}
|
||||
|
||||
// TODO: Make it public after we resolve performance issue
|
||||
private def margin(features: Vector): Array[Float] = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
|
||||
_booster.predict(data = dm, outPutMargin = true)(0)
|
||||
}
|
||||
|
||||
private def probability(features: Vector): Array[Float] = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
|
||||
_booster.predict(data = dm, outPutMargin = false)(0)
|
||||
}
|
||||
|
||||
override def predict(features: Vector): Double = {
|
||||
throw new Exception("XGBoost-Spark does not support online prediction")
|
||||
}
|
||||
|
||||
// Actually we don't use this function at all, to make it pass compiler check.
|
||||
override def predictRaw(features: Vector): Vector = {
|
||||
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
|
||||
}
|
||||
|
||||
// Actually we don't use this function at all, to make it pass compiler check.
|
||||
override def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
|
||||
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
|
||||
}
|
||||
|
||||
// Generate raw prediction and probability prediction.
|
||||
private def transformInternal(dataset: Dataset[_]): DataFrame = {
|
||||
|
||||
val schema = StructType(dataset.schema.fields ++
|
||||
Seq(StructField(name = _rawPredictionCol, dataType =
|
||||
ArrayType(FloatType, containsNull = false), nullable = false)) ++
|
||||
Seq(StructField(name = _probabilityCol, dataType =
|
||||
ArrayType(FloatType, containsNull = false), nullable = false)))
|
||||
|
||||
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
|
||||
val appName = dataset.sparkSession.sparkContext.appName
|
||||
|
||||
val rdd = dataset.rdd.mapPartitions { rowIterator =>
|
||||
if (rowIterator.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val (rowItr1, rowItr2) = rowIterator.duplicate
|
||||
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
|
||||
$(featuresCol))).toList.iterator
|
||||
import DataUtils._
|
||||
val cacheInfo = {
|
||||
if ($(useExternalMemory)) {
|
||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
|
||||
try {
|
||||
val rawPredictionItr = {
|
||||
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator
|
||||
}
|
||||
val probabilityItr = {
|
||||
bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator
|
||||
}
|
||||
Rabit.shutdown()
|
||||
rowItr1.zip(rawPredictionItr).zip(probabilityItr).map {
|
||||
case ((originals: Row, rawPrediction: Row), probability: Row) =>
|
||||
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
|
||||
}
|
||||
} finally {
|
||||
dm.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator[Row]()
|
||||
}
|
||||
}
|
||||
|
||||
bBooster.unpersist(blocking = false)
|
||||
|
||||
dataset.sparkSession.createDataFrame(rdd, schema)
|
||||
}
|
||||
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".transform() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
// Output selected columns only.
|
||||
// This is a bit complicated since it tries to avoid repeated computation.
|
||||
var outputData = transformInternal(dataset)
|
||||
var numColsOutput = 0
|
||||
|
||||
val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) =>
|
||||
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
|
||||
}
|
||||
|
||||
val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) =>
|
||||
if (numClasses == 2) {
|
||||
Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
|
||||
} else {
|
||||
Vectors.dense(probability.map(_.toDouble).toArray)
|
||||
}
|
||||
}
|
||||
|
||||
val predictUDF = udf { (probability: mutable.WrappedArray[Float]) =>
|
||||
// From XGBoost probability to MLlib prediction
|
||||
val probabilities = if (numClasses == 2) {
|
||||
Array(1 - probability(0), probability(0)).map(_.toDouble)
|
||||
} else {
|
||||
probability.map(_.toDouble).toArray
|
||||
}
|
||||
probability2prediction(Vectors.dense(probabilities))
|
||||
}
|
||||
|
||||
if ($(rawPredictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if ($(probabilityCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if (numColsOutput == 0) {
|
||||
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
||||
" since no output columns were set.")
|
||||
}
|
||||
outputData
|
||||
.toDF
|
||||
.drop(col(_rawPredictionCol))
|
||||
.drop(col(_probabilityCol))
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostClassificationModel = {
|
||||
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
|
||||
newModel.setSummary(summary).setParent(parent)
|
||||
}
|
||||
|
||||
override def write: MLWriter =
|
||||
new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
|
||||
}
|
||||
|
||||
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
|
||||
|
||||
private val _rawPredictionCol = "_rawPrediction"
|
||||
private val _probabilityCol = "_probability"
|
||||
|
||||
override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
|
||||
|
||||
override def load(path: String): XGBoostClassificationModel = super.load(path)
|
||||
|
||||
private[XGBoostClassificationModel]
|
||||
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
// Save metadata and Params
|
||||
implicit val format = DefaultFormats
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
|
||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||
// Save model data
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
||||
outputStream.writeInt(instance.numClasses)
|
||||
instance._booster.saveModel(outputStream)
|
||||
outputStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] {
|
||||
|
||||
/** Checked against metadata when loading model */
|
||||
private val className = classOf[XGBoostClassificationModel].getName
|
||||
|
||||
override def load(path: String): XGBoostClassificationModel = {
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
|
||||
|
||||
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
|
||||
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
|
||||
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
|
||||
val numClasses = dataInStream.readInt()
|
||||
|
||||
val booster = SXGBoost.loadModel(dataInStream)
|
||||
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
|
||||
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
|
||||
model
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.ml.Predictor
|
||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.FloatType
|
||||
import org.apache.spark.sql.{Dataset, Row}
|
||||
import org.json4s.DefaultFormats
|
||||
|
||||
/**
|
||||
* XGBoost Estimator to produce a XGBoost model
|
||||
*/
|
||||
class XGBoostEstimator private[spark](
|
||||
override val uid: String, xgboostParams: Map[String, Any])
|
||||
extends Predictor[Vector, XGBoostEstimator, XGBoostModel]
|
||||
with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {
|
||||
|
||||
def this(xgboostParams: Map[String, Any]) =
|
||||
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
|
||||
|
||||
def this(uid: String) = this(uid, Map[String, Any]())
|
||||
|
||||
// called in fromXGBParamMapToParams only when eval_metric is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null))
|
||||
if (objFunc == null) {
|
||||
"rmse"
|
||||
} else {
|
||||
// compute default metric based on specified objective
|
||||
val isClassificationTask = XGBoost.isClassificationTask(xgboostParams)
|
||||
if (!isClassificationTask) {
|
||||
// default metric for regression or ranking
|
||||
if (objFunc.toString.startsWith("rank")) {
|
||||
"map"
|
||||
} else {
|
||||
"rmse"
|
||||
}
|
||||
} else {
|
||||
// default metric for classification
|
||||
if (objFunc.toString.startsWith("multi")) {
|
||||
// multi
|
||||
"merror"
|
||||
} else {
|
||||
// binary
|
||||
"error"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def fromXGBParamMapToParams(): Unit = {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
params.find(_.name == paramName) match {
|
||||
case None =>
|
||||
case Some(_: DoubleParam) =>
|
||||
set(paramName, paramValue.toString.toDouble)
|
||||
case Some(_: BooleanParam) =>
|
||||
set(paramName, paramValue.toString.toBoolean)
|
||||
case Some(_: IntParam) =>
|
||||
set(paramName, paramValue.toString.toInt)
|
||||
case Some(_: FloatParam) =>
|
||||
set(paramName, paramValue.toString.toFloat)
|
||||
case Some(_: Param[_]) =>
|
||||
set(paramName, paramValue)
|
||||
}
|
||||
}
|
||||
if (xgboostParams.get("eval_metric").isEmpty) {
|
||||
set("eval_metric", setupDefaultEvalMetric())
|
||||
}
|
||||
}
|
||||
|
||||
fromXGBParamMapToParams()
|
||||
|
||||
private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
|
||||
val xgbParamMap = new mutable.HashMap[String, Any]()
|
||||
for (param <- params) {
|
||||
xgbParamMap += param.name -> $(param)
|
||||
}
|
||||
val r = xgbParamMap.toMap
|
||||
if (!XGBoost.isClassificationTask(r) || $(numClasses) == 2) {
|
||||
r - "num_class"
|
||||
} else {
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
|
||||
var newTrainingSet = trainingSet
|
||||
if (!trainingSet.columns.contains($(baseMarginCol))) {
|
||||
newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
|
||||
}
|
||||
if (!trainingSet.columns.contains($(weightCol))) {
|
||||
newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
|
||||
}
|
||||
newTrainingSet
|
||||
}
|
||||
|
||||
/**
|
||||
* produce a XGBoostModel by fitting the given dataset
|
||||
*/
|
||||
override def train(trainingSet: Dataset[_]): XGBoostModel = {
|
||||
val instances = ensureColumns(trainingSet).select(
|
||||
col($(featuresCol)),
|
||||
col($(labelCol)).cast(FloatType),
|
||||
col($(baseMarginCol)).cast(FloatType),
|
||||
col($(weightCol)).cast(FloatType)
|
||||
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
}
|
||||
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
|
||||
}
|
||||
transformSchema(trainingSet.schema, logging = true)
|
||||
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
|
||||
val trainedModel = XGBoost.trainDistributed(instances, derivedXGBoosterParamMap,
|
||||
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing)).setParent(this)
|
||||
val returnedModel = copyValues(trainedModel, extractParamMap())
|
||||
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
|
||||
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
|
||||
}
|
||||
returnedModel
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostEstimator = {
|
||||
defaultCopy(extra).asInstanceOf[XGBoostEstimator]
|
||||
}
|
||||
|
||||
override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
|
||||
}
|
||||
|
||||
object XGBoostEstimator extends MLReadable[XGBoostEstimator] {
|
||||
|
||||
override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader
|
||||
|
||||
override def load(path: String): XGBoostEstimator = super.load(path)
|
||||
|
||||
private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
|
||||
extends MLWriter {
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
|
||||
instance.fromParamsToXGBParamMap("custom_obj") == null,
|
||||
"we do not support persist XGBoostEstimator with customized evaluator and objective" +
|
||||
" function for now")
|
||||
implicit val format = DefaultFormats
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||
}
|
||||
}
|
||||
|
||||
private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {
|
||||
|
||||
override def load(path: String): XGBoostEstimator = {
|
||||
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
|
||||
val cls = Utils.classForName(metadata.className)
|
||||
val instance =
|
||||
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
|
||||
DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
|
||||
instance.asInstanceOf[XGBoostEstimator]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,387 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
|
||||
|
||||
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
|
||||
|
||||
import org.apache.spark.ml.PredictionModel
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
|
||||
import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType}
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.json4s.DefaultFormats
|
||||
|
||||
/**
|
||||
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
|
||||
*/
|
||||
abstract class XGBoostModel(protected var _booster: Booster)
|
||||
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
|
||||
with Params with MLWritable {
|
||||
|
||||
private var trainingSummary: Option[XGBoostTrainingSummary] = None
|
||||
|
||||
/**
|
||||
* Returns summary (e.g. train/test objective history) of model on the
|
||||
* training set. An exception is thrown if no summary is available.
|
||||
*/
|
||||
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
|
||||
throw new IllegalStateException("No training summary available for this XGBoostModel")
|
||||
}
|
||||
|
||||
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
|
||||
trainingSummary = Some(summary)
|
||||
this
|
||||
}
|
||||
|
||||
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
|
||||
|
||||
// scalastyle:off
|
||||
|
||||
final val useExternalMemory = new BooleanParam(this, "use_external_memory",
|
||||
"whether to use external memory for prediction")
|
||||
|
||||
setDefault(useExternalMemory, false)
|
||||
|
||||
def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value)
|
||||
|
||||
// scalastyle:on
|
||||
|
||||
/**
|
||||
* Predict leaf instances with the given test set (represented as RDD)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
*/
|
||||
def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = {
|
||||
import DataUtils._
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
if (testSamples.nonEmpty) {
|
||||
val dMatrix = new DMatrix(testSamples.map(_.asXGB))
|
||||
try {
|
||||
broadcastBooster.value.predictLeaf(dMatrix).iterator
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate XGBoostModel with a RDD-wrapped dataset
|
||||
*
|
||||
* NOTE: you have to specify value of either eval or iter; when you specify both, this method
|
||||
* adopts the default eval metric of model
|
||||
*
|
||||
* @param evalDataset the dataset used for evaluation
|
||||
* @param evalName the name of evaluation
|
||||
* @param evalFunc the customized evaluation function, null by default to use the default metric
|
||||
* of model
|
||||
* @param iter the current iteration, -1 to be null to use customized evaluation functions
|
||||
* @param groupData group data specify each group size for ranking task. Top level corresponds
|
||||
* to partition id, second level is the group sizes.
|
||||
* @return the average metric over all partitions
|
||||
*/
|
||||
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
|
||||
iter: Int = -1, useExternalCache: Boolean = false,
|
||||
groupData: Seq[Seq[Int]] = null): String = {
|
||||
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
|
||||
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
|
||||
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
|
||||
val appName = evalDataset.context.appName
|
||||
val allEvalMetrics = evalDataset.mapPartitions {
|
||||
labeledPointsPartition =>
|
||||
import DataUtils._
|
||||
if (labeledPointsPartition.hasNext) {
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName = {
|
||||
if (broadcastUseExternalCache.value) {
|
||||
s"$appName-${TaskContext.get().stageId()}-$evalName" +
|
||||
s"-deval_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(labeledPointsPartition.map(_.asXGB), cacheFileName)
|
||||
try {
|
||||
if (groupData != null) {
|
||||
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
(evalFunc, iter) match {
|
||||
case (null, _) => {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
}
|
||||
case _ => {
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
}.filter(_.isDefined).collect()
|
||||
val evalPrefix = allEvalMetrics.map(_.get._1).head
|
||||
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
|
||||
s"$evalPrefix = $evalMetricMean"
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given test set (represented as RDD)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
* @param missingValue the specified value to represent the missing value
|
||||
*/
|
||||
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val sampleArray = testSamples.toArray
|
||||
val numRows = sampleArray.length
|
||||
if (numRows == 0) {
|
||||
Iterator()
|
||||
} else {
|
||||
val numColumns = sampleArray.head.size
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
// translate to required format
|
||||
val flatSampleArray = new Array[Float](numRows * numColumns)
|
||||
for (i <- flatSampleArray.indices) {
|
||||
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
||||
}
|
||||
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
||||
try {
|
||||
broadcastBooster.value.predict(dMatrix).iterator
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given test set (represented as RDD)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
* @param useExternalCache whether to use external cache for the test set
|
||||
* @param outputMargin whether to output raw untransformed margin value
|
||||
*/
|
||||
def predict(
|
||||
testSet: RDD[MLVector],
|
||||
useExternalCache: Boolean = false,
|
||||
outputMargin: Boolean = false): RDD[Array[Float]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
val appName = testSet.context.appName
|
||||
testSet.mapPartitions { testSamples =>
|
||||
if (testSamples.nonEmpty) {
|
||||
import DataUtils._
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName = {
|
||||
if (useExternalCache) {
|
||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(testSamples.map(_.asXGB), cacheFileName)
|
||||
try {
|
||||
broadcastBooster.value.predict(dMatrix).iterator
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def transformImpl(testSet: Dataset[_]): DataFrame
|
||||
|
||||
/**
|
||||
* append leaf index of each row as an additional column in the original dataset
|
||||
*
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
def transformLeaf(testSet: Dataset[_]): DataFrame = {
|
||||
val predictRDD = produceRowRDD(testSet, predLeaf = true)
|
||||
setPredictionCol("predLeaf")
|
||||
transformSchema(testSet.schema, logging = true)
|
||||
testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol),
|
||||
ArrayType(FloatType, containsNull = false)))
|
||||
}
|
||||
|
||||
protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false,
|
||||
predLeaf: Boolean = false): RDD[Row] = {
|
||||
val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
|
||||
val appName = testSet.sparkSession.sparkContext.appName
|
||||
testSet.rdd.mapPartitions {
|
||||
rowIterator =>
|
||||
if (rowIterator.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val (rowItr1, rowItr2) = rowIterator.duplicate
|
||||
val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector](
|
||||
$(featuresCol))).toList.iterator
|
||||
import DataUtils._
|
||||
val cachePrefix = {
|
||||
if ($(useExternalMemory)) {
|
||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val testDataset = new DMatrix(vectorIterator.map(_.asXGB), cachePrefix)
|
||||
try {
|
||||
val rawPredictResults = {
|
||||
if (!predLeaf) {
|
||||
broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
|
||||
} else {
|
||||
broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
|
||||
}
|
||||
}
|
||||
Rabit.shutdown()
|
||||
// concatenate original data partition and predictions
|
||||
rowItr1.zip(rawPredictResults).map {
|
||||
case (originalColumns: Row, predictColumn: Row) =>
|
||||
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
|
||||
}
|
||||
} finally {
|
||||
testDataset.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator[Row]()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* produces the prediction results and append as an additional column in the original dataset
|
||||
* NOTE: the prediction results is kept as the original format of xgboost
|
||||
*
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
override def transform(testSet: Dataset[_]): DataFrame = {
|
||||
transformImpl(testSet)
|
||||
}
|
||||
|
||||
private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = {
|
||||
outputStream.writeUTF(getFeaturesCol)
|
||||
outputStream.writeUTF(getLabelCol)
|
||||
outputStream.writeUTF(getPredictionCol)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as to HDFS-compatible file system.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = {
|
||||
val path = new Path(modelPath)
|
||||
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
|
||||
// output model type
|
||||
this match {
|
||||
case model: XGBoostClassificationModel =>
|
||||
outputStream.writeUTF("_cls_")
|
||||
saveGeneralModelParam(outputStream)
|
||||
outputStream.writeUTF(model.getRawPredictionCol)
|
||||
outputStream.writeInt(model.numClasses)
|
||||
// threshold
|
||||
// threshold length
|
||||
if (!isDefined(model.thresholds)) {
|
||||
outputStream.writeInt(-1)
|
||||
} else {
|
||||
val thresholdLength = model.getThresholds.length
|
||||
outputStream.writeInt(thresholdLength)
|
||||
for (i <- 0 until thresholdLength) {
|
||||
outputStream.writeDouble(model.getThresholds(i))
|
||||
}
|
||||
}
|
||||
case model: XGBoostRegressionModel =>
|
||||
outputStream.writeUTF("_reg_")
|
||||
// eventual prediction col
|
||||
saveGeneralModelParam(outputStream)
|
||||
}
|
||||
// booster
|
||||
_booster.saveModel(outputStream)
|
||||
outputStream.close()
|
||||
}
|
||||
|
||||
def booster: Booster = _booster
|
||||
|
||||
def version: Int = this.booster.booster.getVersion
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
|
||||
|
||||
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
|
||||
}
|
||||
|
||||
object XGBoostModel extends MLReadable[XGBoostModel] {
|
||||
private[spark] def apply(booster: Booster, isClassification: Boolean): XGBoostModel = {
|
||||
if (!isClassification) {
|
||||
new XGBoostRegressionModel(booster)
|
||||
} else {
|
||||
new XGBoostClassificationModel(booster)
|
||||
}
|
||||
}
|
||||
|
||||
override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
|
||||
|
||||
override def load(path: String): XGBoostModel = super.load(path)
|
||||
|
||||
private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter {
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
implicit val format = DefaultFormats
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||
val dataPath = new Path(path, "data").toString
|
||||
instance.saveModelAsHadoopFile(dataPath)
|
||||
}
|
||||
}
|
||||
|
||||
private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
|
||||
|
||||
override def load(path: String): XGBoostModel = {
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
val dataPath = new Path(path, "data").toString
|
||||
// not used / all data resides in platform independent xgboost model file
|
||||
// val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
|
||||
XGBoost.loadModelFromHadoopFile(dataPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import org.apache.spark.ml.linalg.{Vector => MLVector}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType}
|
||||
|
||||
/**
|
||||
* class of XGBoost model used for regression task
|
||||
*/
|
||||
class XGBoostRegressionModel private[spark](override val uid: String, booster: Booster)
|
||||
extends XGBoostModel(booster) {
|
||||
|
||||
def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster)
|
||||
|
||||
// only called in copy()
|
||||
def this(uid: String) = this(uid, null)
|
||||
|
||||
override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
|
||||
transformSchema(testSet.schema, logging = true)
|
||||
val predictRDD = produceRowRDD(testSet)
|
||||
val tempPredColName = $(predictionCol) + "_temp"
|
||||
val transformerForArrayTypedPredCol =
|
||||
udf((regressionResults: mutable.WrappedArray[Float]) => regressionResults(0))
|
||||
testSet.sparkSession.createDataFrame(predictRDD,
|
||||
schema = testSet.schema.add(tempPredColName, ArrayType(FloatType, containsNull = false))
|
||||
).withColumn(
|
||||
$(predictionCol),
|
||||
transformerForArrayTypedPredCol.apply(col(tempPredColName))).drop(tempPredColName)
|
||||
}
|
||||
|
||||
override protected def predict(features: MLVector): Double = {
|
||||
throw new Exception("XGBoost does not support online prediction for now")
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostRegressionModel = {
|
||||
val newModel = copyValues(new XGBoostRegressionModel(booster), extra)
|
||||
newModel.setSummary(summary)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml._
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.json4s.DefaultFormats
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
|
||||
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
|
||||
with ParamMapFuncs
|
||||
|
||||
class XGBoostRegressor (
|
||||
override val uid: String,
|
||||
private val xgboostParams: Map[String, Any])
|
||||
extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
|
||||
with XGBoostRegressorParams with DefaultParamsWritable {
|
||||
|
||||
def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]())
|
||||
|
||||
def this(uid: String) = this(uid, Map[String, Any]())
|
||||
|
||||
def this(xgboostParams: Map[String, Any]) = this(
|
||||
Identifiable.randomUID("xgbr"), xgboostParams)
|
||||
|
||||
XGBoostToMLlibParams(xgboostParams)
|
||||
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
|
||||
|
||||
def setGroupCol(value: String): this.type = set(groupCol, value)
|
||||
|
||||
// setters for general params
|
||||
def setNumRound(value: Int): this.type = set(numRound, value)
|
||||
|
||||
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
|
||||
|
||||
def setNthread(value: Int): this.type = set(nthread, value)
|
||||
|
||||
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
|
||||
|
||||
def setSilent(value: Int): this.type = set(silent, value)
|
||||
|
||||
def setMissing(value: Float): this.type = set(missing, value)
|
||||
|
||||
def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value)
|
||||
|
||||
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
|
||||
|
||||
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
|
||||
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
// setters for booster params
|
||||
def setBooster(value: String): this.type = set(booster, value)
|
||||
|
||||
def setEta(value: Double): this.type = set(eta, value)
|
||||
|
||||
def setGamma(value: Double): this.type = set(gamma, value)
|
||||
|
||||
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
|
||||
|
||||
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
|
||||
|
||||
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
|
||||
|
||||
def setSubsample(value: Double): this.type = set(subsample, value)
|
||||
|
||||
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
|
||||
|
||||
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
|
||||
|
||||
def setLambda(value: Double): this.type = set(lambda, value)
|
||||
|
||||
def setAlpha(value: Double): this.type = set(alpha, value)
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
|
||||
def setSketchEps(value: Double): this.type = set(sketchEps, value)
|
||||
|
||||
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
|
||||
|
||||
def setSampleType(value: String): this.type = set(sampleType, value)
|
||||
|
||||
def setNormalizeType(value: String): this.type = set(normalizeType, value)
|
||||
|
||||
def setRateDrop(value: Double): this.type = set(rateDrop, value)
|
||||
|
||||
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
|
||||
|
||||
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
|
||||
|
||||
// setters for learning params
|
||||
def setObjective(value: String): this.type = set(objective, value)
|
||||
|
||||
def setBaseScore(value: Double): this.type = set(baseScore, value)
|
||||
|
||||
def setEvalMetric(value: String): this.type = set(evalMetric, value)
|
||||
|
||||
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
|
||||
|
||||
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
if ($(objective).startsWith("rank")) {
|
||||
"map"
|
||||
} else {
|
||||
"rmse"
|
||||
}
|
||||
}
|
||||
|
||||
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
|
||||
|
||||
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
|
||||
set(evalMetric, setupDefaultEvalMetric())
|
||||
}
|
||||
|
||||
val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
|
||||
lit(Float.NaN)
|
||||
} else {
|
||||
col($(baseMarginCol))
|
||||
}
|
||||
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
|
||||
|
||||
val instances: RDD[XGBLabeledPoint] = dataset.select(
|
||||
col($(labelCol)).cast(FloatType),
|
||||
col($(featuresCol)),
|
||||
weight.cast(FloatType),
|
||||
group.cast(IntegerType),
|
||||
baseMargin.cast(FloatType)
|
||||
).rdd.map {
|
||||
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
|
||||
val (indices, values) = features match {
|
||||
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
|
||||
case v: DenseVector => (null, v.values.map(_.toFloat))
|
||||
}
|
||||
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
|
||||
}
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
val derivedXGBParamMap = MLlib2XGBoostParams
|
||||
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
|
||||
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
|
||||
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
|
||||
$(missing))
|
||||
val model = new XGBoostRegressionModel(uid, _booster)
|
||||
val summary = XGBoostTrainingSummary(_metrics)
|
||||
model.setSummary(summary)
|
||||
model
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra)
|
||||
}
|
||||
|
||||
object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
|
||||
|
||||
override def load(path: String): XGBoostRegressor = super.load(path)
|
||||
}
|
||||
|
||||
class XGBoostRegressionModel private[ml] (
|
||||
override val uid: String,
|
||||
private[spark] val _booster: Booster)
|
||||
extends PredictionModel[Vector, XGBoostRegressionModel]
|
||||
with XGBoostRegressorParams with MLWritable with Serializable {
|
||||
|
||||
import XGBoostRegressionModel._
|
||||
|
||||
// only called in copy()
|
||||
def this(uid: String) = this(uid, null)
|
||||
|
||||
private var trainingSummary: Option[XGBoostTrainingSummary] = None
|
||||
|
||||
/**
|
||||
* Returns summary (e.g. train/test objective history) of model on the
|
||||
* training set. An exception is thrown if no summary is available.
|
||||
*/
|
||||
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
|
||||
throw new IllegalStateException("No training summary available for this XGBoostModel")
|
||||
}
|
||||
|
||||
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
|
||||
trainingSummary = Some(summary)
|
||||
this
|
||||
}
|
||||
|
||||
override def predict(features: Vector): Double = {
|
||||
throw new Exception("XGBoost-Spark does not support online prediction")
|
||||
}
|
||||
|
||||
private def transformInternal(dataset: Dataset[_]): DataFrame = {
|
||||
|
||||
val schema = StructType(dataset.schema.fields ++
|
||||
Seq(StructField(name = _originalPredictionCol, dataType =
|
||||
ArrayType(FloatType, containsNull = false), nullable = false)))
|
||||
|
||||
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
|
||||
val appName = dataset.sparkSession.sparkContext.appName
|
||||
|
||||
val rdd = dataset.rdd.mapPartitions { rowIterator =>
|
||||
if (rowIterator.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val (rowItr1, rowItr2) = rowIterator.duplicate
|
||||
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
|
||||
$(featuresCol))).toList.iterator
|
||||
import DataUtils._
|
||||
val cacheInfo = {
|
||||
if ($(useExternalMemory)) {
|
||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
|
||||
try {
|
||||
val originalPredictionItr = {
|
||||
bBooster.value.predict(dm).map(Row(_)).iterator
|
||||
}
|
||||
Rabit.shutdown()
|
||||
rowItr1.zip(originalPredictionItr).map {
|
||||
case (originals: Row, originalPrediction: Row) =>
|
||||
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
|
||||
}
|
||||
} finally {
|
||||
dm.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator[Row]()
|
||||
}
|
||||
}
|
||||
|
||||
bBooster.unpersist(blocking = false)
|
||||
|
||||
dataset.sparkSession.createDataFrame(rdd, schema)
|
||||
}
|
||||
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
|
||||
// Output selected columns only.
|
||||
// This is a bit complicated since it tries to avoid repeated computation.
|
||||
var outputData = transformInternal(dataset)
|
||||
var numColsOutput = 0
|
||||
|
||||
val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
|
||||
originalPrediction(0).toDouble
|
||||
}
|
||||
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
outputData = outputData
|
||||
.withColumn($(predictionCol), predictUDF(col(_originalPredictionCol)))
|
||||
numColsOutput += 1
|
||||
}
|
||||
|
||||
if (numColsOutput == 0) {
|
||||
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
|
||||
" since no output columns were set.")
|
||||
}
|
||||
outputData.toDF.drop(col(_originalPredictionCol))
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostRegressionModel = {
|
||||
val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra)
|
||||
newModel.setSummary(summary).setParent(parent)
|
||||
}
|
||||
|
||||
override def write: MLWriter =
|
||||
new XGBoostRegressionModel.XGBoostRegressionModelWriter(this)
|
||||
}
|
||||
|
||||
object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
|
||||
|
||||
private val _originalPredictionCol = "_originalPrediction"
|
||||
|
||||
override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
|
||||
|
||||
override def load(path: String): XGBoostRegressionModel = super.load(path)
|
||||
|
||||
private[XGBoostRegressionModel]
|
||||
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
// Save metadata and Params
|
||||
implicit val format = DefaultFormats
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||
// Save model data
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
|
||||
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
|
||||
instance._booster.saveModel(outputStream)
|
||||
outputStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] {
|
||||
|
||||
/** Checked against metadata when loading model */
|
||||
private val className = classOf[XGBoostRegressionModel].getName
|
||||
|
||||
override def load(path: String): XGBoostRegressionModel = {
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
|
||||
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
|
||||
|
||||
val dataPath = new Path(path, "data").toString
|
||||
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
|
||||
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
|
||||
|
||||
val booster = SXGBoost.loadModel(dataInStream)
|
||||
val model = new XGBoostRegressionModel(metadata.uid, booster)
|
||||
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
|
||||
model
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,40 +20,48 @@ import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
|
||||
trait BoosterParams extends Params {
|
||||
private[spark] trait BoosterParams extends Params {
|
||||
|
||||
/**
|
||||
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
||||
*/
|
||||
val boosterType = new Param[String](this, "booster",
|
||||
final val booster = new Param[String](this, "booster",
|
||||
s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
|
||||
(value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))
|
||||
|
||||
final def getBooster: String = $(booster)
|
||||
|
||||
/**
|
||||
* step size shrinkage used in update to prevents overfitting. After each boosting step, we
|
||||
* can directly get the weights of new features and eta actually shrinks the feature weights
|
||||
* to make the boosting process more conservative. [default=0.3] range: [0,1]
|
||||
*/
|
||||
val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
|
||||
final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
|
||||
" overfitting. After each boosting step, we can directly get the weights of new features." +
|
||||
" and eta actually shrinks the feature weights to make the boosting process more conservative.",
|
||||
(value: Double) => value >= 0 && value <= 1)
|
||||
|
||||
final def getEta: Double = $(eta)
|
||||
|
||||
/**
|
||||
* minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||
* the larger, the more conservative the algorithm will be. [default=0] range: [0,
|
||||
* Double.MaxValue]
|
||||
*/
|
||||
val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a further" +
|
||||
" partition on a leaf node of the tree. the larger, the more conservative the algorithm" +
|
||||
" will be.", (value: Double) => value >= 0)
|
||||
final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
|
||||
"further partition on a leaf node of the tree. the larger, the more conservative the " +
|
||||
"algorithm will be.", (value: Double) => value >= 0)
|
||||
|
||||
final def getGamma: Double = $(gamma)
|
||||
|
||||
/**
|
||||
* maximum depth of a tree, increase this value will make model more complex / likely to be
|
||||
* overfitting. [default=6] range: [1, Int.MaxValue]
|
||||
*/
|
||||
val maxDepth = new IntParam(this, "max_depth", "maximum depth of a tree, increase this value" +
|
||||
" will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
|
||||
final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
|
||||
"value will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
|
||||
|
||||
final def getMaxDepth: Int = $(maxDepth)
|
||||
|
||||
/**
|
||||
* minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
|
||||
@@ -62,13 +70,15 @@ trait BoosterParams extends Params {
|
||||
* to minimum number of instances needed to be in each node. The larger, the more conservative
|
||||
* the algorithm will be. [default=1] range: [0, Double.MaxValue]
|
||||
*/
|
||||
val minChildWeight = new DoubleParam(this, "min_child_weight", "minimum sum of instance" +
|
||||
final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
|
||||
" weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
|
||||
" the sum of instance weight less than min_child_weight, then the building process will" +
|
||||
" give up further partitioning. In linear regression mode, this simply corresponds to minimum" +
|
||||
" number of instances needed to be in each node. The larger, the more conservative" +
|
||||
" the algorithm will be.", (value: Double) => value >= 0)
|
||||
|
||||
final def getMinChildWeight: Double = $(minChildWeight)
|
||||
|
||||
/**
|
||||
* Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
|
||||
* means there is no constraint. If it is set to a positive value, it can help making the update
|
||||
@@ -76,90 +86,113 @@ trait BoosterParams extends Params {
|
||||
* regression when class is extremely imbalanced. Set it to value of 1-10 might help control the
|
||||
* update. [default=0] range: [0, Double.MaxValue]
|
||||
*/
|
||||
val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow each" +
|
||||
" tree's weight" +
|
||||
final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " +
|
||||
"each tree's weight" +
|
||||
" estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
|
||||
" to a positive value, it can help making the update step more conservative. Usually this" +
|
||||
" parameter is not needed, but it might help in logistic regression when class is extremely" +
|
||||
" imbalanced. Set it to value of 1-10 might help control the update",
|
||||
(value: Double) => value >= 0)
|
||||
|
||||
final def getMaxDeltaStep: Double = $(maxDeltaStep)
|
||||
|
||||
/**
|
||||
* subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly
|
||||
* collected half of the data instances to grow trees and this will prevent overfitting.
|
||||
* [default=1] range:(0,1]
|
||||
*/
|
||||
val subSample = new DoubleParam(this, "subsample", "subsample ratio of the training instance." +
|
||||
" Setting it to 0.5 means that XGBoost randomly collected half of the data instances to" +
|
||||
" grow trees and this will prevent overfitting.", (value: Double) => value <= 1 && value > 0)
|
||||
final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
|
||||
"instance. Setting it to 0.5 means that XGBoost randomly collected half of the data " +
|
||||
"instances to grow trees and this will prevent overfitting.",
|
||||
(value: Double) => value <= 1 && value > 0)
|
||||
|
||||
final def getSubsample: Double = $(subsample)
|
||||
|
||||
/**
|
||||
* subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
|
||||
*/
|
||||
val colSampleByTree = new DoubleParam(this, "colsample_bytree", "subsample ratio of columns" +
|
||||
" when constructing each tree.", (value: Double) => value <= 1 && value > 0)
|
||||
final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
|
||||
"columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)
|
||||
|
||||
final def getColsampleBytree: Double = $(colsampleBytree)
|
||||
|
||||
/**
|
||||
* subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
|
||||
*/
|
||||
val colSampleByLevel = new DoubleParam(this, "colsample_bylevel", "subsample ratio of columns" +
|
||||
" for each split, in each level.", (value: Double) => value <= 1 && value > 0)
|
||||
final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
|
||||
"columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)
|
||||
|
||||
final def getColsampleBylevel: Double = $(colsampleBylevel)
|
||||
|
||||
/**
|
||||
* L2 regularization term on weights, increase this value will make model more conservative.
|
||||
* [default=1]
|
||||
*/
|
||||
val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, increase this" +
|
||||
" value will make model more conservative.", (value: Double) => value >= 0)
|
||||
final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, " +
|
||||
"increase this value will make model more conservative.", (value: Double) => value >= 0)
|
||||
|
||||
final def getLambda: Double = $(lambda)
|
||||
|
||||
/**
|
||||
* L1 regularization term on weights, increase this value will make model more conservative.
|
||||
* [default=0]
|
||||
*/
|
||||
val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase this" +
|
||||
" value will make model more conservative.", (value: Double) => value >= 0)
|
||||
final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase " +
|
||||
"this value will make model more conservative.", (value: Double) => value >= 0)
|
||||
|
||||
final def getAlpha: Double = $(alpha)
|
||||
|
||||
/**
|
||||
* The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
|
||||
* [default='auto']
|
||||
*/
|
||||
val treeMethod = new Param[String](this, "tree_method",
|
||||
final val treeMethod = new Param[String](this, "treeMethod",
|
||||
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
|
||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||
|
||||
final def getTreeMethod: String = $(treeMethod)
|
||||
|
||||
/**
|
||||
* growth policy for fast histogram algorithm
|
||||
*/
|
||||
val growthPolicty = new Param[String](this, "grow_policy",
|
||||
final val growPolicy = new Param[String](this, "growPolicy",
|
||||
"growth policy for fast histogram algorithm",
|
||||
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
|
||||
|
||||
final def getGrowPolicy: String = $(growPolicy)
|
||||
|
||||
/**
|
||||
* maximum number of bins in histogram
|
||||
*/
|
||||
val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram",
|
||||
final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
|
||||
(value: Int) => value > 0)
|
||||
|
||||
final def getMaxBins: Int = $(maxBins)
|
||||
|
||||
/**
|
||||
* This is only used for approximate greedy algorithm.
|
||||
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
|
||||
* number of bins, this comes with theoretical guarantee with sketch accuracy.
|
||||
* [default=0.03] range: (0, 1)
|
||||
*/
|
||||
val sketchEps = new DoubleParam(this, "sketch_eps",
|
||||
final val sketchEps = new DoubleParam(this, "sketchEps",
|
||||
"This is only used for approximate greedy algorithm. This roughly translated into" +
|
||||
" O(1 / sketch_eps) number of bins. Compared to directly select number of bins, this comes" +
|
||||
" with theoretical guarantee with sketch accuracy.",
|
||||
(value: Double) => value < 1 && value > 0)
|
||||
|
||||
final def getSketchEps: Double = $(sketchEps)
|
||||
|
||||
/**
|
||||
* Control the balance of positive and negative weights, useful for unbalanced classes. A typical
|
||||
* value to consider: sum(negative cases) / sum(positive cases). [default=1]
|
||||
*/
|
||||
val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of positive" +
|
||||
" and negative weights, useful for unbalanced classes. A typical value to consider:" +
|
||||
final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " +
|
||||
"positive and negative weights, useful for unbalanced classes. A typical value to consider:" +
|
||||
" sum(negative cases) / sum(positive cases)")
|
||||
|
||||
final def getScalePosWeight: Double = $(scalePosWeight)
|
||||
|
||||
// Dart boosters
|
||||
|
||||
/**
|
||||
@@ -167,72 +200,59 @@ trait BoosterParams extends Params {
|
||||
* Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
|
||||
* "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
|
||||
*/
|
||||
val sampleType = new Param[String](this, "sample_type", "type of sampling algorithm, options:" +
|
||||
" {'uniform', 'weighted'}",
|
||||
final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
|
||||
"options: {'uniform', 'weighted'}",
|
||||
(value: String) => BoosterParams.supportedSampleType.contains(value))
|
||||
|
||||
final def getSampleType: String = $(sampleType)
|
||||
|
||||
/**
|
||||
* Parameter of Dart booster.
|
||||
* type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
|
||||
*/
|
||||
val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
|
||||
final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
|
||||
" algorithm, options: {'tree', 'forest'}",
|
||||
(value: String) => BoosterParams.supportedNormalizeType.contains(value))
|
||||
|
||||
final def getNormalizeType: String = $(normalizeType)
|
||||
|
||||
/**
|
||||
* Parameter of Dart booster.
|
||||
* dropout rate. [default=0.0] range: [0.0, 1.0]
|
||||
*/
|
||||
val rateDrop = new DoubleParam(this, "rate_drop", "dropout rate", (value: Double) =>
|
||||
final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
|
||||
value >= 0 && value <= 1)
|
||||
|
||||
final def getRateDrop: Double = $(rateDrop)
|
||||
|
||||
/**
|
||||
* Parameter of Dart booster.
|
||||
* probability of skip dropout. If a dropout is skipped, new trees are added in the same manner
|
||||
* as gbtree. [default=0.0] range: [0.0, 1.0]
|
||||
*/
|
||||
val skipDrop = new DoubleParam(this, "skip_drop", "probability of skip dropout. If" +
|
||||
final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skip dropout. If" +
|
||||
" a dropout is skipped, new trees are added in the same manner as gbtree.",
|
||||
(value: Double) => value >= 0 && value <= 1)
|
||||
|
||||
final def getSkipDrop: Double = $(skipDrop)
|
||||
|
||||
// linear booster
|
||||
/**
|
||||
* Parameter of linear booster
|
||||
* L2 regularization term on bias, default 0(no L1 reg on bias because it is not important)
|
||||
*/
|
||||
val lambdaBias = new DoubleParam(this, "lambda_bias", "L2 regularization term on bias, default" +
|
||||
" 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
|
||||
final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
|
||||
"default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
|
||||
|
||||
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
final def getLambdaBias: Double = $(lambdaBias)
|
||||
|
||||
setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
||||
growthPolicty -> "depthwise", maxBins -> 16,
|
||||
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
|
||||
growPolicy -> "depthwise", maxBins -> 16,
|
||||
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
|
||||
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
||||
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
||||
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)
|
||||
|
||||
/**
|
||||
* Explains all params of this instance. See `explainParam()`.
|
||||
*/
|
||||
override def explainParams(): String = {
|
||||
// TODO: filter some parameters according to the booster type
|
||||
val boosterTypeStr = $(boosterType)
|
||||
val validParamList = {
|
||||
if (boosterTypeStr == "gblinear") {
|
||||
// gblinear
|
||||
params.filter(param => param.name == "lambda" ||
|
||||
param.name == "alpha" || param.name == "lambda_bias")
|
||||
} else if (boosterTypeStr != "dart") {
|
||||
// gbtree
|
||||
params.filter(param => param.name != "sample_type" &&
|
||||
param.name != "normalize_type" && param.name != "rate_drop" && param.name != "skip_drop")
|
||||
} else {
|
||||
// dart
|
||||
params.filter(_.name != "lambda_bias")
|
||||
}
|
||||
}
|
||||
explainParam(boosterType) + "\n" ++ validParamList.map(explainParam).mkString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] object BoosterParams {
|
||||
|
||||
@@ -16,84 +16,104 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import com.google.common.base.CaseFormat
|
||||
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
|
||||
|
||||
import org.apache.spark.ml.param._
|
||||
import scala.collection.mutable
|
||||
|
||||
trait GeneralParams extends Params {
|
||||
private[spark] trait GeneralParams extends Params {
|
||||
|
||||
/**
|
||||
* The number of rounds for boosting
|
||||
*/
|
||||
val round = new IntParam(this, "num_round", "The number of rounds for boosting",
|
||||
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def getNumRound: Int = $(numRound)
|
||||
|
||||
/**
|
||||
* number of workers used to train xgboost model. default: 1
|
||||
*/
|
||||
val nWorkers = new IntParam(this, "nworkers", "number of workers used to run xgboost",
|
||||
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def getNumWorkers: Int = $(numWorkers)
|
||||
|
||||
/**
|
||||
* number of threads used by per worker. default 1
|
||||
*/
|
||||
val numThreadPerTask = new IntParam(this, "nthread", "number of threads used by per worker",
|
||||
final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def getNthread: Int = $(nthread)
|
||||
|
||||
/**
|
||||
* whether to use external memory as cache. default: false
|
||||
*/
|
||||
val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external" +
|
||||
"memory as cache")
|
||||
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
|
||||
"whether to use external memory as cache")
|
||||
|
||||
final def getUseExternalMemory: Boolean = $(useExternalMemory)
|
||||
|
||||
/**
|
||||
* 0 means printing running messages, 1 means silent mode. default: 0
|
||||
*/
|
||||
val silent = new IntParam(this, "silent",
|
||||
final val silent = new IntParam(this, "silent",
|
||||
"0 means printing running messages, 1 means silent mode.",
|
||||
(value: Int) => value >= 0 && value <= 1)
|
||||
|
||||
final def getSilent: Int = $(silent)
|
||||
|
||||
/**
|
||||
* customized objective function provided by user. default: null
|
||||
*/
|
||||
val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
|
||||
final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
|
||||
"provided by user")
|
||||
|
||||
/**
|
||||
* customized evaluation function provided by user. default: null
|
||||
*/
|
||||
val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
|
||||
"provided by user")
|
||||
final val customEval = new CustomEvalParam(this, "customEval",
|
||||
"customized evaluation function provided by user")
|
||||
|
||||
/**
|
||||
* the value treated as missing. default: Float.NaN
|
||||
*/
|
||||
val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||
final val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||
|
||||
final def getMissing: Float = $(missing)
|
||||
|
||||
/**
|
||||
* the maximum time to wait for the job requesting new workers. default: 30 minutes
|
||||
*/
|
||||
val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
|
||||
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
|
||||
" value is set smaller than or equal to 0.")
|
||||
final val timeoutRequestWorkers = new LongParam(this, "timeoutRequestWorkers", "the maximum " +
|
||||
"time to request new Workers if numCores are insufficient. The timeout will be disabled " +
|
||||
"if this value is set smaller than or equal to 0.")
|
||||
|
||||
final def getTimeoutRequestWorkers: Long = $(timeoutRequestWorkers)
|
||||
|
||||
/**
|
||||
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
|
||||
*/
|
||||
val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
|
||||
"save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load " +
|
||||
"the checkpoint with highest version as the starting point for training. If " +
|
||||
final val checkpointPath = new Param[String](this, "checkpointPath", "the hdfs folder to load " +
|
||||
"and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will " +
|
||||
"load the checkpoint with highest version as the starting point for training. If " +
|
||||
"checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
|
||||
|
||||
final def getCheckpointPath: String = $(checkpointPath)
|
||||
|
||||
/**
|
||||
* Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that
|
||||
* the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must
|
||||
* also be set if the checkpoint interval is greater than 0.
|
||||
*/
|
||||
val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
|
||||
"interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
|
||||
"checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
|
||||
" interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)
|
||||
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval",
|
||||
"set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained " +
|
||||
"model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be " +
|
||||
"set if the checkpoint interval is greater than 0.",
|
||||
(interval: Int) => interval == -1 || interval >= 1)
|
||||
|
||||
final def getCheckpointInterval: Int = $(checkpointInterval)
|
||||
|
||||
/**
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
@@ -122,15 +142,87 @@ trait GeneralParams extends Params {
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*/
|
||||
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
|
||||
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
||||
|
||||
/** Random seed for the C++ part of XGBoost and train/test splitting. */
|
||||
val seed = new LongParam(this, "seed", "random seed")
|
||||
final val seed = new LongParam(this, "seed", "random seed")
|
||||
|
||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||
final def getSeed: Long = $(seed)
|
||||
|
||||
setDefault(numRound -> 1, numWorkers -> 1, nthread -> 1,
|
||||
useExternalMemory -> false, silent -> 0,
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
|
||||
checkpointPath -> "", checkpointInterval -> -1
|
||||
)
|
||||
}
|
||||
|
||||
trait HasBaseMarginCol extends Params {
|
||||
|
||||
/**
|
||||
* Param for initial prediction (aka base margin) column name.
|
||||
* @group param
|
||||
*/
|
||||
final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
|
||||
"Initial prediction (aka base margin) column name.")
|
||||
|
||||
/** @group getParam */
|
||||
final def getBaseMarginCol: String = $(baseMarginCol)
|
||||
}
|
||||
|
||||
trait HasGroupCol extends Params {
|
||||
|
||||
/**
|
||||
* Param for group column name.
|
||||
* @group param
|
||||
*/
|
||||
final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
|
||||
|
||||
/** @group getParam */
|
||||
final def getGroupCol: String = $(groupCol)
|
||||
|
||||
}
|
||||
|
||||
trait HasNumClass extends Params {
|
||||
|
||||
/**
|
||||
* number of classes
|
||||
*/
|
||||
final val numClass = new IntParam(this, "numClass", "number of classes")
|
||||
|
||||
/** @group getParam */
|
||||
final def getNumClass: Int = $(numClass)
|
||||
}
|
||||
|
||||
private[spark] trait ParamMapFuncs extends Params {
|
||||
|
||||
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
params.find(_.name == name) match {
|
||||
case None =>
|
||||
case Some(_: DoubleParam) =>
|
||||
set(name, paramValue.toString.toDouble)
|
||||
case Some(_: BooleanParam) =>
|
||||
set(name, paramValue.toString.toBoolean)
|
||||
case Some(_: IntParam) =>
|
||||
set(name, paramValue.toString.toInt)
|
||||
case Some(_: FloatParam) =>
|
||||
set(name, paramValue.toString.toFloat)
|
||||
case Some(_: Param[_]) =>
|
||||
set(name, paramValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def MLlib2XGBoostParams: Map[String, Any] = {
|
||||
val xgboostParams = new mutable.HashMap[String, Any]()
|
||||
for (param <- params) {
|
||||
if (isDefined(param)) {
|
||||
val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name)
|
||||
xgboostParams += name -> $(param)
|
||||
}
|
||||
}
|
||||
xgboostParams.toMap
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,76 +20,70 @@ import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
trait LearningTaskParams extends Params {
|
||||
|
||||
/**
|
||||
* number of tasks to learn
|
||||
*/
|
||||
val numClasses = new IntParam(this, "num_class", "number of classes")
|
||||
private[spark] trait LearningTaskParams extends Params {
|
||||
|
||||
/**
|
||||
* Specify the learning task and the corresponding learning objective.
|
||||
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
|
||||
* multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:linear
|
||||
*/
|
||||
val objective = new Param[String](this, "objective", "objective function used for training," +
|
||||
s" options: {${LearningTaskParams.supportedObjective.mkString(",")}",
|
||||
final val objective = new Param[String](this, "objective", "objective function used for " +
|
||||
s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
|
||||
(value: String) => LearningTaskParams.supportedObjective.contains(value))
|
||||
|
||||
final def getObjective: String = $(objective)
|
||||
|
||||
/**
|
||||
* the initial prediction score of all instances, global bias. default=0.5
|
||||
*/
|
||||
val baseScore = new DoubleParam(this, "base_score", "the initial prediction score of all" +
|
||||
final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" +
|
||||
" instances, global bias")
|
||||
|
||||
final def getBaseScore: Double = $(baseScore)
|
||||
|
||||
/**
|
||||
* evaluation metrics for validation data, a default metric will be assigned according to
|
||||
* objective(rmse for regression, and error for classification, mean average precision for
|
||||
* ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map,
|
||||
* gamma-deviance
|
||||
*/
|
||||
val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
|
||||
" data, a default metric will be assigned according to objective (rmse for regression, and" +
|
||||
" error for classification, mean average precision for ranking), options: " +
|
||||
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
|
||||
"validation data, a default metric will be assigned according to objective " +
|
||||
"(rmse for regression, and error for classification, mean average precision for ranking), " +
|
||||
s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
|
||||
|
||||
final def getEvalMetric: String = $(evalMetric)
|
||||
|
||||
/**
|
||||
* group data specify each group sizes for ranking task. To correspond to partition of
|
||||
* training data, it is nested.
|
||||
*/
|
||||
val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
|
||||
" for ranking task. To correspond to partition of training data, it is nested.")
|
||||
|
||||
/**
|
||||
* Initial prediction (aka base margin) column name.
|
||||
*/
|
||||
val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
|
||||
|
||||
/**
|
||||
* Instance weights column name.
|
||||
*/
|
||||
val weightCol = new Param[String](this, "weightCol", "weight column name")
|
||||
final val groupData = new GroupDataParam(this, "groupData", "group data specify each group " +
|
||||
"size for ranking task. To correspond to partition of training data, it is nested.")
|
||||
|
||||
/**
|
||||
* Fraction of training points to use for testing.
|
||||
*/
|
||||
val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
"fraction of training points to use for testing",
|
||||
ParamValidators.inRange(0, 1))
|
||||
|
||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||
|
||||
/**
|
||||
* If non-zero, the training will be stopped after a specified number
|
||||
* of consecutive increases in any evaluation metric.
|
||||
*/
|
||||
val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
|
||||
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
|
||||
"number of rounds of decreasing eval metric to tolerate before " +
|
||||
"stopping the training",
|
||||
(value: Int) => value == 0 || value > 1)
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
|
||||
baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0,
|
||||
numEarlyStoppingRounds -> 0)
|
||||
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, groupData -> null,
|
||||
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
|
||||
}
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
Reference in New Issue
Block a user