[jvm-packages] XGBoost Spark integration refactor (#3387)

* add back train method but mark as deprecated

* fix scalastyle error

* [jvm-packages] XGBoost Spark integration refactor. (#3313)

* XGBoost Spark integration refactor.

* Make corresponding update for xgboost4j-example

* Address comments.

* [jvm-packages] Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLlib (#3326)

* Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLlib

* Fix extra space.

* [jvm-packages] XGBoost Spark supports ranking with group data. (#3369)

* XGBoost Spark supports ranking with group data.

* Use Iterator.duplicate to prevent OOM.

* Update CheckpointManagerSuite.scala

* Resolve conflicts

Authored by Yanbo Liang on 2018-06-18 15:39:18 -07:00; committed by Nan Zhu
parent e6696337e4
commit 2c4359e914
34 changed files with 1921 additions and 2173 deletions
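
In short, this refactor replaces the XGBoost.trainWithDataFrame/trainWithRDD entry points with MLlib-style estimators (XGBoostClassifier, XGBoostRegressor), and moves training controls such as the round count into the parameter map. A minimal migration sketch, assuming a DataFrame trainDF with the default "features" and "label" columns (all values illustrative):

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // Before: XGBoost.trainWithDataFrame(trainDF, paramMap, round = 10, nWorkers = 2)
    // After: the round and worker counts become entries of the parameter map,
    // and training happens through the standard Estimator.fit path.
    val paramMap = Map(
      "eta" -> 0.1,
      "max_depth" -> 2,
      "objective" -> "binary:logistic",
      "num_round" -> 10,
      "nWorkers" -> 2)
    val model = new XGBoostClassifier(paramMap).fit(trainDF)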

View File

@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.Source
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoost}
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
@@ -160,10 +160,10 @@ object SparkModelTuningTool {
private def crossValidation(
xgboostParam: Map[String, Any],
trainingData: Dataset[_]): TrainValidationSplitModel = {
val xgbEstimator = new XGBoostEstimator(xgboostParam).setFeaturesCol("features").
val xgbEstimator = new XGBoostRegressor(xgboostParam).setFeaturesCol("features").
setLabelCol("logSales")
val paramGrid = new ParamGridBuilder()
.addGrid(xgbEstimator.round, Array(20, 50))
.addGrid(xgbEstimator.numRound, Array(20, 50))
.addGrid(xgbEstimator.eta, Array(0.1, 0.4))
.build()
val tv = new TrainValidationSplit()
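
The hunk above ends at the TrainValidationSplit construction. For context, a hedged sketch of how the renamed numRound param plugs into tuning; the evaluator and train ratio below are assumptions, not part of this diff:

    // Sketch only: completes crossValidation under stated assumptions.
    val tv = new TrainValidationSplit()
      .setEstimator(xgbEstimator)
      .setEvaluator(new RegressionEvaluator().setLabelCol("logSales"))
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8) // assumed split ratio
    tv.fit(trainingData)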

View File

@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf
@@ -45,9 +45,10 @@ object SparkWithDataFrame {
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.trainWithDataFrame(
trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
"objective" -> "binary:logistic",
"num_round" -> numRound,
"nWorkers" -> args(1).toInt).toMap
val xgboostModel = new XGBoostClassifier(paramMap).fit(trainDF)
// xgboost-spark appends the column containing prediction results
xgboostModel.transform(testDF).show()
}

View File

@@ -1,58 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example.spark
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}
object SparkWithRDD {
def main(args: Array[String]): Unit = {
if (args.length != 5) {
println(
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
implicit val sc = new SparkContext(sparkConf)
val inputTrainPath = args(2)
val inputTestPath = args(3)
val outputModelPath = args(4)
// number of iterations
val numRound = args(0).toInt
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).map(lp =>
MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray)))
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath)
.map(lp => new MLDenseVector(lp.features.toArray))
// training parameters
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
useExternalMemory = true)
xgboostModel.predict(testSet, missingValue = Float.NaN)
// save model to HDFS path
xgboostModel.saveModelAsHadoopFile(outputModelPath)
}
}
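
This RDD-based example is deleted together with XGBoost.trainWithRDD; the DataFrame example above replaces it. A hedged migration sketch for an existing LabeledPoint RDD, assuming a SparkSession named spark:

    import spark.implicits._

    // org.apache.spark.ml.feature.LabeledPoint is a case class, so the RDD
    // converts directly into a DataFrame with "label"/"features" columns.
    val trainDF = trainRDD.toDF()
    val model = new XGBoostClassifier(paramMap).fit(trainDF)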

View File

@@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
@@ -63,9 +64,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
val version = versions.max
val fullPath = getPath(version)
logger.info(s"Start training from previous booster at $fullPath")
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
model.booster.booster.setVersion(version)
model.booster
val booster = SXGBoost.loadModel(fullPath)
booster.booster.setVersion(version)
booster
} else {
null
}
@@ -76,12 +77,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
*
* @param checkpoint the checkpoint to save as an XGBoostModel
*/
private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(checkpoint.version)
logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
checkpoint.saveModelAsHadoopFile(fullPath)(sc)
val fullPath = getPath(checkpoint.getVersion)
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
checkpoint.saveModel(fullPath)
prevModelPaths.foreach(path => fs.delete(path, true))
}
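
Checkpoints now hold a bare Booster rather than an XGBoostModel wrapper. A minimal sketch of the new round-trip, mirroring the two changes above (the path is illustrative):

    import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}

    // Load a checkpointed booster and restore its iteration counter,
    // as the loading branch above does:
    val booster = SXGBoost.loadModel("/checkpoints/8.model") // hypothetical path
    booster.booster.setVersion(8)

    // Persist it again, as updateCheckpoint does above:
    booster.saveModel("/checkpoints/8.model")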

View File

@@ -21,16 +21,15 @@ import java.nio.file.Files
import scala.collection.mutable
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
@@ -134,7 +133,7 @@ object XGBoost extends Serializable {
fromBaseMarginsToArray(baseMargins), cacheDirName)
try {
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
.map(_.toString.toInt).getOrElse(0)
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.train, params, round,
@@ -148,89 +147,6 @@
}.cache()
}
/**
* Train XGBoost model with the DataFrame-represented data
*
* @param trainingData the training set represented as DataFrame
* @param params Map containing the parameters to configure XGBoost
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing The value which represents a missing value in the dataset
* @param featureCol the name of input column, "features" as default value
* @param labelCol the name of output column, "label" as default value
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
*/
@throws(classOf[XGBoostError])
def trainWithDataFrame(
trainingData: Dataset[_],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN,
featureCol: String = "features",
labelCol: String = "label"): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
val estimator = new XGBoostEstimator(params)
// assigning general parameters
estimator.
set(estimator.useExternalMemory, useExternalMemory).
set(estimator.round, round).
set(estimator.nWorkers, nWorkers).
set(estimator.customObj, obj).
set(estimator.customEval, eval).
set(estimator.missing, missing).
setFeaturesCol(featureCol).
setLabelCol(labelCol).
fit(trainingData)
}
private[spark] def isClassificationTask(params: Map[String, Any]): Boolean = {
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
objective != null && {
val objStr = objective.toString
objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" &&
!objStr.startsWith("rank:")
}
}
/**
* Train XGBoost model with the RDD-represented data
*
* @param trainingData the training set represented as RDD
* @param params Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing the value represented the missing value in the dataset
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
*/
@deprecated("Use XGBoost.trainWithRDD instead.")
def train(
trainingData: RDD[MLLabeledPoint],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
}
private def overrideParamsAccordingToTaskCPUs(
params: Map[String, Any],
sc: SparkContext): Map[String, Any] = {
@@ -259,39 +175,8 @@ object XGBoost extends Serializable {
}
/**
* Train XGBoost model with the RDD-represented data
*
* @param trainingData the training set represented as RDD
* @param params Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing The value which represents a missing value in the dataset
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
* @return XGBoostModel when successful training
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
def trainWithRDD(
trainingData: RDD[MLLabeledPoint],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
import DataUtils._
val xgbTrainingData = trainingData.map { case MLLabeledPoint(label, features) =>
features.asXGB.copy(label = label.toFloat)
}
trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval,
useExternalMemory, missing)
}
@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
@@ -301,7 +186,7 @@ object XGBoost extends Serializable {
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now")
@@ -350,20 +235,15 @@
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread)
if (checkpointRound < round) {
prevBooster = model.booster
checkpointManager.updateCheckpoint(model)
prevBooster = booster
checkpointManager.updateCheckpoint(prevBooster)
}
model
(booster, metrics)
} finally {
tracker.stop()
}
@@ -383,17 +263,14 @@
private def postTrackerReturnProcessing(
trackerReturnVal: Int,
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
sparkJobThread: Thread,
isClassificationTask: Boolean
): XGBoostModel = {
sparkJobThread: Thread): (Booster, Map[String, Array[Float]]) = {
if (trackerReturnVal == 0) {
// Copies of the final booster and the corresponding metrics
// reside in each partition of the `distributedBoostersAndMetrics`.
// Any of them can be used to create the model.
val (booster, metrics) = distributedBoostersAndMetrics.first()
val xgboostModel = XGBoostModel(booster, isClassificationTask)
distributedBoostersAndMetrics.unpersist(false)
xgboostModel.setSummary(XGBoostTrainingSummary(metrics))
(booster, metrics)
} else {
try {
if (sparkJobThread.isAlive) {
@@ -407,64 +284,6 @@
}
}
private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = {
val featureCol = inputStream.readUTF()
val labelCol = inputStream.readUTF()
val predictionCol = inputStream.readUTF()
(featureCol, labelCol, predictionCol)
}
private def setGeneralModelParams(
featureCol: String,
labelCol: String,
predCol: String,
xgBoostModel: XGBoostModel): XGBoostModel = {
xgBoostModel.setFeaturesCol(featureCol)
xgBoostModel.setLabelCol(labelCol)
xgBoostModel.setPredictionCol(predCol)
}
/**
* Load XGBoost model from path in HDFS-compatible file system
*
* @param modelPath The path of the file representing the model
* @return The loaded model
*/
def loadModelFromHadoopFile(modelPath: String)(implicit sparkContext: SparkContext):
XGBoostModel = {
val path = new Path(modelPath)
val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path)
val modelType = dataInStream.readUTF()
val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream)
modelType match {
case "_cls_" =>
val rawPredictionCol = dataInStream.readUTF()
val numClasses = dataInStream.readInt()
val thresholdLength = dataInStream.readInt()
var thresholds: Array[Double] = null
if (thresholdLength != -1) {
thresholds = new Array[Double](thresholdLength)
for (i <- 0 until thresholdLength) {
thresholds(i) = dataInStream.readDouble()
}
}
val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel).
asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol)
if (thresholdLength != -1) {
xgBoostModel.setThresholds(thresholds)
}
xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses
xgBoostModel
case "_reg_" =>
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
case other =>
throw new XGBoostError(s"Unknown model type $other. Supported types " +
s"are: ['_reg_', '_cls_'].")
}
}
}
private class Watches private(
@@ -489,12 +308,29 @@ private class Watches private(
private object Watches {
def buildGroups(groups: Seq[Int]): Seq[Int] = {
val output = mutable.ArrayBuffer.empty[Int]
var count = 1
var i = 1
while (i < groups.length) {
if (groups(i) != groups(i - 1)) {
output += count
count = 1
} else {
count += 1
}
i += 1
}
output += count
output
}
def apply(
params: Map[String, Any],
labeledPoints: Iterator[XGBLabeledPoint],
baseMarginsOpt: Option[Array[Float]],
cacheDirName: Option[String]): Watches = {
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
val r = new Random(seed)
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
@@ -506,8 +342,18 @@ private object Watches {
accepted
}
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
val (trainIter1, trainIter2) = trainPoints.duplicate
val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
trainMatrix.setGroup(trainGroups)
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
if (trainTestRatio < 1.0) {
val testGroups = buildGroups(testPoints.map(_.group)).toArray
testMatrix.setGroup(testGroups)
}
r.setSeed(seed)
for (baseMargins <- baseMarginsOpt) {
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
@@ -515,11 +361,6 @@
testMatrix.setBaseMargin(testMargin)
}
// TODO: use group attribute from the points.
if (params.contains("groupData") && params("groupData") != null) {
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
TaskContext.getPartitionId()).toArray)
}
new Watches(trainMatrix, testMatrix, cacheDirName)
}
}
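
The new buildGroups helper run-length encodes partition-local group ids into the group-size array that DMatrix.setGroup expects. A worked example of the logic defined above (ids illustrative; records are assumed ordered by group id within a partition):

    // ids 1,1,1,5,5,9 form consecutive runs of length 3, 2 and 1:
    val sizes = buildGroups(Seq(1, 1, 1, 5, 5, 9)) // Seq(3, 2, 1)
    // trainMatrix.setGroup(sizes.toArray) then marks the group boundaries
    // for the ranking objective.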

View File

@@ -1,181 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}
/**
* class of the XGBoost model used for classification task
*/
class XGBoostClassificationModel private[spark](
override val uid: String, booster: Booster)
extends XGBoostModel(booster) {
def this(booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), booster)
// only called in copy()
def this(uid: String) = this(uid, null)
// scalastyle:off
/**
* whether to output raw margin
*/
final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value")
setDefault(outputMargin, false)
def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel]
/**
* the name of the column storing the raw prediction value, either probabilities (as default) or
* raw margin value
*/
final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).")
setDefault(rawPredictionCol, "probabilities")
final def getRawPredictionCol: String = $(rawPredictionCol)
def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel]
/**
* Thresholds in multi-class classification
*/
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
def getThresholds: Array[Double] = $(thresholds)
def setThresholds(value: Array[Double]): XGBoostClassificationModel =
set(thresholds, value).asInstanceOf[XGBoostClassificationModel]
// scalastyle:on
// generate dataframe containing raw prediction column which is typed as Vector
private def predictRaw(
testSet: Dataset[_],
temporalColName: Option[String] = None,
forceTransformedScore: Option[Boolean] = None): DataFrame = {
val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin)))
val colName = temporalColName.getOrElse($(rawPredictionCol))
val tempColName = colName + "_arraytype"
val dsWithArrayTypedRawPredCol = testSet.sparkSession.createDataFrame(predictRDD, schema = {
testSet.schema.add(tempColName, ArrayType(FloatType, containsNull = false))
})
val transformerForProbabilitiesArray =
(rawPredArray: mutable.WrappedArray[Float]) =>
if (numClasses == 2) {
Array(1 - rawPredArray(0), rawPredArray(0)).map(_.toDouble)
} else {
rawPredArray.map(_.toDouble).array
}
dsWithArrayTypedRawPredCol.withColumn(colName,
udf((rawPredArray: mutable.WrappedArray[Float]) =>
new MLDenseVector(transformerForProbabilitiesArray(rawPredArray))).apply(col(tempColName))).
drop(tempColName)
}
private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = {
val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol"))
val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol"))
val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF)
val allColumnNames = testSet.columns ++ Seq($(predictionCol))
tempDF.select(allColumnNames(0), allColumnNames.tail: _*)
}
private def argMax(vector: Array[Double]): Double = {
vector.zipWithIndex.maxBy(_._1)._2
}
private def raw2prediction(rawPrediction: MLDenseVector): Double = {
if (!isDefined(thresholds)) {
argMax(rawPrediction.values)
} else {
probability2prediction(rawPrediction)
}
}
private def probability2prediction(probability: MLDenseVector): Double = {
if (!isDefined(thresholds)) {
argMax(probability.values)
} else {
val thresholds: Array[Double] = getThresholds
val scaledProbability =
probability.values.zip(thresholds).map { case (p, t) =>
if (t == 0.0) Double.PositiveInfinity else p / t
}
argMax(scaledProbability)
}
}
override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
transformSchema(testSet.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
if ($(outputMargin)) {
setRawPredictionCol("margin")
}
var outputData = testSet
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
outputData = predictRaw(testSet)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
if ($(rawPredictionCol).nonEmpty) {
require(!$(outputMargin), "XGBoost does not support output final prediction with" +
" untransformed margin. Please set predictionCol as \"\" when setting outputMargin as" +
" true")
val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol)))
outputData = outputData.withColumn($(predictionCol), rawToPredUDF)
} else {
outputData = fromFeatureToPrediction(testSet)
}
numColsOutput += 1
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData.toDF()
}
private[spark] var numOfClasses = 2
def numClasses: Int = numOfClasses
override def copy(extra: ParamMap): XGBoostClassificationModel = {
val newModel = copyValues(new XGBoostClassificationModel(booster), extra)
newModel.setSummary(summary)
}
override protected def predict(features: MLVector): Double = {
throw new Exception("XGBoost does not support online prediction ")
}
}
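
The thresholds logic removed here carries over to the new model via Spark's ProbabilisticClassificationModel: the predicted class maximizes p/t rather than the raw probability. A worked example of that rule (numbers illustrative):

    val probability = Array(0.30, 0.70)
    val thresholds  = Array(0.25, 0.75)
    val scaled = probability.zip(thresholds).map { case (p, t) => p / t }
    // scaled == Array(1.2, 0.933...): class 0 wins despite its lower
    // raw probability, because its threshold is lower.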

View File

@@ -0,0 +1,432 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import scala.collection.mutable
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.json4s.DefaultFormats
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
class XGBoostClassifier (
override val uid: String,
private val xgboostParams: Map[String, Any])
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
with XGBoostClassifierParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]())
def this(uid: String) = this(uid, Map[String, Any]())
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbc"), xgboostParams)
XGBoostToMLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
def setNumClass(value: Int): this.type = set(numClass, value)
// setters for general params
def setNumRound(value: Int): this.type = set(numRound, value)
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
def setNthread(value: Int): this.type = set(nthread, value)
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
def setSilent(value: Int): this.type = set(silent, value)
def setMissing(value: Float): this.type = set(missing, value)
def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value)
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setSeed(value: Long): this.type = set(seed, value)
// setters for booster params
def setBooster(value: String): this.type = set(booster, value)
def setEta(value: Double): this.type = set(eta, value)
def setGamma(value: Double): this.type = set(gamma, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
def setSubsample(value: Double): this.type = set(subsample, value)
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
def setLambda(value: Double): this.type = set(lambda, value)
def setAlpha(value: Double): this.type = set(alpha, value)
def setTreeMethod(value: String): this.type = set(treeMethod, value)
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)
def setSketchEps(value: Double): this.type = set(sketchEps, value)
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
def setSampleType(value: String): this.type = set(sampleType, value)
def setNormalizeType(value: String): this.type = set(normalizeType, value)
def setRateDrop(value: Double): this.type = set(rateDrop, value)
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
// setters for learning params
def setObjective(value: String): this.type = set(objective, value)
def setBaseScore(value: Double): this.type = set(baseScore, value)
def setEvalMetric(value: String): this.type = set(evalMetric, value)
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("multi")) {
// multi
"merror"
} else {
// binary
"error"
}
}
override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
}
val _numClasses = getNumClasses(dataset)
if (isDefined(numClass) && $(numClass) != _numClasses) {
throw new Exception("The number of classes in dataset doesn't match " +
"\'num_class\' in xgboost params.")
}
val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)
} else {
col($(baseMarginCol))
}
val instances: RDD[XGBLabeledPoint] = dataset.select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
baseMargin.cast(FloatType),
weight.cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing))
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
model
}
override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra)
}
object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
override def load(path: String): XGBoostClassifier = super.load(path)
}
class XGBoostClassificationModel private[ml](
override val uid: String,
override val numClasses: Int,
private[spark] val _booster: Booster)
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostClassifierParams with MLWritable with Serializable {
import XGBoostClassificationModel._
// only called in copy()
def this(uid: String) = this(uid, 2, null)
private var trainingSummary: Option[XGBoostTrainingSummary] = None
/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}
// TODO: Make it public after we resolve performance issue
private def margin(features: Vector): Array[Float] = {
import DataUtils._
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
_booster.predict(data = dm, outPutMargin = true)(0)
}
private def probability(features: Vector): Array[Float] = {
import DataUtils._
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
_booster.predict(data = dm, outPutMargin = false)(0)
}
override def predict(features: Vector): Double = {
throw new Exception("XGBoost-Spark does not support online prediction")
}
// Actually we don't use this function at all, to make it pass compiler check.
override def predictRaw(features: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
}
// Actually we don't use this function at all, to make it pass compiler check.
override def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
}
// Generate raw prediction and probability prediction.
private def transformInternal(dataset: Dataset[_]): DataFrame = {
val schema = StructType(dataset.schema.fields ++
Seq(StructField(name = _rawPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)) ++
Seq(StructField(name = _probabilityCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName
val rdd = dataset.rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
$(featuresCol))).toList.iterator
import DataUtils._
val cacheInfo = {
if ($(useExternalMemory)) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
try {
val rawPredictionItr = {
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator
}
val probabilityItr = {
bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator
}
Rabit.shutdown()
rowItr1.zip(rawPredictionItr).zip(probabilityItr).map {
case ((originals: Row, rawPrediction: Row), probability: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
}
} finally {
dm.delete()
}
} else {
Iterator[Row]()
}
}
bBooster.unpersist(blocking = false)
dataset.sparkSession.createDataFrame(rdd, schema)
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = transformInternal(dataset)
var numColsOutput = 0
val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) =>
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
}
val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) =>
if (numClasses == 2) {
Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
} else {
Vectors.dense(probability.map(_.toDouble).toArray)
}
}
val predictUDF = udf { (probability: mutable.WrappedArray[Float]) =>
// From XGBoost probability to MLlib prediction
val probabilities = if (numClasses == 2) {
Array(1 - probability(0), probability(0)).map(_.toDouble)
} else {
probability.map(_.toDouble).toArray
}
probability2prediction(Vectors.dense(probabilities))
}
if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
numColsOutput += 1
}
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData
.toDF
.drop(col(_rawPredictionCol))
.drop(col(_probabilityCol))
}
override def copy(extra: ParamMap): XGBoostClassificationModel = {
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
newModel.setSummary(summary).setParent(parent)
}
override def write: MLWriter =
new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
}
object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
private val _rawPredictionCol = "_rawPrediction"
private val _probabilityCol = "_probability"
override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
override def load(path: String): XGBoostClassificationModel = super.load(path)
private[XGBoostClassificationModel]
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
outputStream.writeInt(instance.numClasses)
instance._booster.saveModel(outputStream)
outputStream.close()
}
}
private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] {
/** Checked against metadata when loading model */
private val className = classOf[XGBoostClassificationModel].getName
override def load(path: String): XGBoostClassificationModel = {
implicit val sc = super.sparkSession.sparkContext
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
val numClasses = dataInStream.readInt()
val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
model
}
}
}
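
Persistence of the new model goes through the MLlib writer/reader pair defined above, which stores numClasses plus the raw booster under data/XGBoostClassificationModel. A round-trip sketch, assuming a fitted model and a test DataFrame testDF (the path is illustrative):

    model.write.overwrite().save("/models/xgbc") // hypothetical path
    val restored = XGBoostClassificationModel.load("/models/xgbc")
    restored.transform(testDF).show()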

View File

@@ -1,186 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.Predictor
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.{Dataset, Row}
import org.json4s.DefaultFormats
/**
* XGBoost Estimator to produce a XGBoost model
*/
class XGBoostEstimator private[spark](
override val uid: String, xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostEstimator, XGBoostModel]
with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {
def this(xgboostParams: Map[String, Any]) =
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
def this(uid: String) = this(uid, Map[String, Any]())
// called in fromXGBParamMapToParams only when eval_metric is not defined
private def setupDefaultEvalMetric(): String = {
val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null))
if (objFunc == null) {
"rmse"
} else {
// compute default metric based on specified objective
val isClassificationTask = XGBoost.isClassificationTask(xgboostParams)
if (!isClassificationTask) {
// default metric for regression or ranking
if (objFunc.toString.startsWith("rank")) {
"map"
} else {
"rmse"
}
} else {
// default metric for classification
if (objFunc.toString.startsWith("multi")) {
// multi
"merror"
} else {
// binary
"error"
}
}
}
}
private def fromXGBParamMapToParams(): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
params.find(_.name == paramName) match {
case None =>
case Some(_: DoubleParam) =>
set(paramName, paramValue.toString.toDouble)
case Some(_: BooleanParam) =>
set(paramName, paramValue.toString.toBoolean)
case Some(_: IntParam) =>
set(paramName, paramValue.toString.toInt)
case Some(_: FloatParam) =>
set(paramName, paramValue.toString.toFloat)
case Some(_: Param[_]) =>
set(paramName, paramValue)
}
}
if (xgboostParams.get("eval_metric").isEmpty) {
set("eval_metric", setupDefaultEvalMetric())
}
}
fromXGBParamMapToParams()
private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
val xgbParamMap = new mutable.HashMap[String, Any]()
for (param <- params) {
xgbParamMap += param.name -> $(param)
}
val r = xgbParamMap.toMap
if (!XGBoost.isClassificationTask(r) || $(numClasses) == 2) {
r - "num_class"
} else {
r
}
}
private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
var newTrainingSet = trainingSet
if (!trainingSet.columns.contains($(baseMarginCol))) {
newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
}
if (!trainingSet.columns.contains($(weightCol))) {
newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
}
newTrainingSet
}
/**
* produce a XGBoostModel by fitting the given dataset
*/
override def train(trainingSet: Dataset[_]): XGBoostModel = {
val instances = ensureColumns(trainingSet).select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
col($(baseMarginCol)).cast(FloatType),
col($(weightCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
}
transformSchema(trainingSet.schema, logging = true)
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
val trainedModel = XGBoost.trainDistributed(instances, derivedXGBoosterParamMap,
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing)).setParent(this)
val returnedModel = copyValues(trainedModel, extractParamMap())
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
}
returnedModel
}
override def copy(extra: ParamMap): XGBoostEstimator = {
defaultCopy(extra).asInstanceOf[XGBoostEstimator]
}
override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
}
object XGBoostEstimator extends MLReadable[XGBoostEstimator] {
override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader
override def load(path: String): XGBoostEstimator = super.load(path)
private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
instance.fromParamsToXGBParamMap("custom_obj") == null,
"we do not support persist XGBoostEstimator with customized evaluator and objective" +
" function for now")
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
}
}
private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {
override def load(path: String): XGBoostEstimator = {
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
instance.asInstanceOf[XGBoostEstimator]
}
}
}

View File

@@ -1,387 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.spark.ml.PredictionModel
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.types.{ArrayType, FloatType}
import org.apache.spark.{SparkContext, TaskContext}
import org.json4s.DefaultFormats
/**
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
*/
abstract class XGBoostModel(protected var _booster: Booster)
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
with Params with MLWritable {
private var trainingSummary: Option[XGBoostTrainingSummary] = None
/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
// scalastyle:off
final val useExternalMemory = new BooleanParam(this, "use_external_memory",
"whether to use external memory for prediction")
setDefault(useExternalMemory, false)
def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value)
// scalastyle:on
/**
* Predict leaf instances with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
*/
def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
if (testSamples.nonEmpty) {
val dMatrix = new DMatrix(testSamples.map(_.asXGB))
try {
broadcastBooster.value.predictLeaf(dMatrix).iterator
} finally {
Rabit.shutdown()
dMatrix.delete()
}
} else {
Iterator()
}
}
}
/**
* evaluate XGBoostModel with a RDD-wrapped dataset
*
* NOTE: you have to specify value of either eval or iter; when you specify both, this method
* adopts the default eval metric of model
*
* @param evalDataset the dataset used for evaluation
* @param evalName the name of evaluation
* @param evalFunc the customized evaluation function, null by default to use the default metric
* of model
* @param iter the current iteration, -1 to be null to use customized evaluation functions
* @param groupData group data specify each group size for ranking task. Top level corresponds
* to partition id, second level is the group sizes.
* @return the average metric over all partitions
*/
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false,
groupData: Seq[Seq[Int]] = null): String = {
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
val appName = evalDataset.context.appName
val allEvalMetrics = evalDataset.mapPartitions {
labeledPointsPartition =>
import DataUtils._
if (labeledPointsPartition.hasNext) {
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
val cacheFileName = {
if (broadcastUseExternalCache.value) {
s"$appName-${TaskContext.get().stageId()}-$evalName" +
s"-deval_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dMatrix = new DMatrix(labeledPointsPartition.map(_.asXGB), cacheFileName)
try {
if (groupData != null) {
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
}
(evalFunc, iter) match {
case (null, _) => {
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
val Array(evName, predNumeric) = predStr.split(":")
Iterator(Some(evName, predNumeric.toFloat))
}
case _ => {
val predictions = broadcastBooster.value.predict(dMatrix)
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
}
}
} finally {
Rabit.shutdown()
dMatrix.delete()
}
} else {
Iterator(None)
}
}.filter(_.isDefined).collect()
val evalPrefix = allEvalMetrics.map(_.get._1).head
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
s"$evalPrefix = $evalMetricMean"
}
/**
* Predict result with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
* @param missingValue the specified value to represent the missing value
*/
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val sampleArray = testSamples.toArray
val numRows = sampleArray.length
if (numRows == 0) {
Iterator()
} else {
val numColumns = sampleArray.head.size
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
// translate to required format
val flatSampleArray = new Array[Float](numRows * numColumns)
for (i <- flatSampleArray.indices) {
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
}
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
try {
broadcastBooster.value.predict(dMatrix).iterator
} finally {
Rabit.shutdown()
dMatrix.delete()
}
}
}
}
/**
* Predict result with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
* @param useExternalCache whether to use external cache for the test set
* @param outputMargin whether to output raw untransformed margin value
*/
def predict(
testSet: RDD[MLVector],
useExternalCache: Boolean = false,
outputMargin: Boolean = false): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
val appName = testSet.context.appName
testSet.mapPartitions { testSamples =>
if (testSamples.nonEmpty) {
import DataUtils._
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val cacheFileName = {
if (useExternalCache) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dMatrix = new DMatrix(testSamples.map(_.asXGB), cacheFileName)
try {
broadcastBooster.value.predict(dMatrix).iterator
} finally {
Rabit.shutdown()
dMatrix.delete()
}
} else {
Iterator()
}
}
}
protected def transformImpl(testSet: Dataset[_]): DataFrame
/**
* append leaf index of each row as an additional column in the original dataset
*
* @return the original dataframe with an additional column containing prediction results
*/
def transformLeaf(testSet: Dataset[_]): DataFrame = {
val predictRDD = produceRowRDD(testSet, predLeaf = true)
setPredictionCol("predLeaf")
transformSchema(testSet.schema, logging = true)
testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol),
ArrayType(FloatType, containsNull = false)))
}
protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false,
predLeaf: Boolean = false): RDD[Row] = {
val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
val appName = testSet.sparkSession.sparkContext.appName
testSet.rdd.mapPartitions {
rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector](
$(featuresCol))).toList.iterator
import DataUtils._
val cachePrefix = {
if ($(useExternalMemory)) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val testDataset = new DMatrix(vectorIterator.map(_.asXGB), cachePrefix)
try {
val rawPredictResults = {
if (!predLeaf) {
broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
} else {
broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
}
}
Rabit.shutdown()
// concatenate original data partition and predictions
rowItr1.zip(rawPredictResults).map {
case (originalColumns: Row, predictColumn: Row) =>
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
}
} finally {
testDataset.delete()
}
} else {
Iterator[Row]()
}
}
}
/**
* produces the prediction results and append as an additional column in the original dataset
* NOTE: the prediction results is kept as the original format of xgboost
*
* @return the original dataframe with an additional column containing prediction results
*/
override def transform(testSet: Dataset[_]): DataFrame = {
transformImpl(testSet)
}
private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = {
outputStream.writeUTF(getFeaturesCol)
outputStream.writeUTF(getLabelCol)
outputStream.writeUTF(getPredictionCol)
}
/**
* Save the model as to HDFS-compatible file system.
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = {
val path = new Path(modelPath)
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
// output model type
this match {
case model: XGBoostClassificationModel =>
outputStream.writeUTF("_cls_")
saveGeneralModelParam(outputStream)
outputStream.writeUTF(model.getRawPredictionCol)
outputStream.writeInt(model.numClasses)
// thresholds: write the length (-1 if undefined), then the values
if (!isDefined(model.thresholds)) {
outputStream.writeInt(-1)
} else {
val thresholdLength = model.getThresholds.length
outputStream.writeInt(thresholdLength)
for (i <- 0 until thresholdLength) {
outputStream.writeDouble(model.getThresholds(i))
}
}
case model: XGBoostRegressionModel =>
outputStream.writeUTF("_reg_")
// eventual prediction col
saveGeneralModelParam(outputStream)
}
// booster
_booster.saveModel(outputStream)
outputStream.close()
}
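// Usage sketch (illustrative): round-tripping a model through an HDFS-compatible
// file system; the path is hypothetical, and the implicit SparkContext is the
// one required by saveModelAsHadoopFile.
//
//   implicit val sc: SparkContext = spark.sparkContext
//   model.saveModelAsHadoopFile("hdfs:///tmp/xgboost/model")
//   val restored: XGBoostModel = XGBoost.loadModelFromHadoopFile("hdfs:///tmp/xgboost/model")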
def booster: Booster = _booster
def version: Int = this.booster.booster.getVersion
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
}
object XGBoostModel extends MLReadable[XGBoostModel] {
private[spark] def apply(booster: Booster, isClassification: Boolean): XGBoostModel = {
if (!isClassification) {
new XGBoostRegressionModel(booster)
} else {
new XGBoostClassificationModel(booster)
}
}
override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
override def load(path: String): XGBoostModel = super.load(path)
private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.saveModelAsHadoopFile(dataPath)
}
}
private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
override def load(path: String): XGBoostModel = {
implicit val sc = super.sparkSession.sparkContext
val dataPath = new Path(path, "data").toString
// not used / all data resides in platform independent xgboost model file
// val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
XGBoost.loadModelFromHadoopFile(dataPath)
}
}
}

View File

@ -1,61 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, FloatType}
/**
* class of XGBoost model used for regression task
*/
class XGBoostRegressionModel private[spark](override val uid: String, booster: Booster)
extends XGBoostModel(booster) {
def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster)
// only called in copy()
def this(uid: String) = this(uid, null)
override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
transformSchema(testSet.schema, logging = true)
val predictRDD = produceRowRDD(testSet)
val tempPredColName = $(predictionCol) + "_temp"
val transformerForArrayTypedPredCol =
udf((regressionResults: mutable.WrappedArray[Float]) => regressionResults(0))
testSet.sparkSession.createDataFrame(predictRDD,
schema = testSet.schema.add(tempPredColName, ArrayType(FloatType, containsNull = false))
).withColumn(
$(predictionCol),
transformerForArrayTypedPredCol.apply(col(tempPredColName))).drop(tempPredColName)
}
override protected def predict(features: MLVector): Double = {
throw new Exception("XGBoost does not support online prediction for now")
}
override def copy(extra: ParamMap): XGBoostRegressionModel = {
val newModel = copyValues(new XGBoostRegressionModel(booster), extra)
newModel.setSummary(summary)
}
}

View File

@ -0,0 +1,356 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.json4s.DefaultFormats
import scala.collection.mutable
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
with ParamMapFuncs
class XGBoostRegressor (
override val uid: String,
private val xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
with XGBoostRegressorParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]())
def this(uid: String) = this(uid, Map[String, Any]())
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbr"), xgboostParams)
XGBoostToMLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)
def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
def setGroupCol(value: String): this.type = set(groupCol, value)
// setters for general params
def setNumRound(value: Int): this.type = set(numRound, value)
def setNumWorkers(value: Int): this.type = set(numWorkers, value)
def setNthread(value: Int): this.type = set(nthread, value)
def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
def setSilent(value: Int): this.type = set(silent, value)
def setMissing(value: Float): this.type = set(missing, value)
def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value)
def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setSeed(value: Long): this.type = set(seed, value)
// setters for booster params
def setBooster(value: String): this.type = set(booster, value)
def setEta(value: Double): this.type = set(eta, value)
def setGamma(value: Double): this.type = set(gamma, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
def setSubsample(value: Double): this.type = set(subsample, value)
def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
def setLambda(value: Double): this.type = set(lambda, value)
def setAlpha(value: Double): this.type = set(alpha, value)
def setTreeMethod(value: String): this.type = set(treeMethod, value)
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)
def setSketchEps(value: Double): this.type = set(sketchEps, value)
def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
def setSampleType(value: String): this.type = set(sampleType, value)
def setNormalizeType(value: String): this.type = set(normalizeType, value)
def setRateDrop(value: Double): this.type = set(rateDrop, value)
def setSkipDrop(value: Double): this.type = set(skipDrop, value)
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
// setters for learning params
def setObjective(value: String): this.type = set(objective, value)
def setBaseScore(value: Double): this.type = set(baseScore, value)
def setEvalMetric(value: String): this.type = set(evalMetric, value)
def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("rank")) {
"map"
} else {
"rmse"
}
}
override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
}
val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)
} else {
col($(baseMarginCol))
}
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))
val instances: RDD[XGBLabeledPoint] = dataset.select(
col($(labelCol)).cast(FloatType),
col($(featuresCol)),
weight.cast(FloatType),
group.cast(IntegerType),
baseMargin.cast(FloatType)
).rdd.map {
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null params set on XGBoostRegressor are included in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing))
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
model
}
override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra)
}
object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
override def load(path: String): XGBoostRegressor = super.load(path)
}
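// Usage sketch (illustrative): fitting the new estimator. `trainDF`/`testDF` are
// assumed DataFrames with "features"/"label" columns; the snake_case keys in the
// map are translated to the camelCase MLlib params by XGBoostToMLlibParams.
//
//   val regressor = new XGBoostRegressor(Map(
//     "eta" -> 0.1,
//     "max_depth" -> 6,
//     "objective" -> "reg:linear",
//     "num_round" -> 100,
//     "num_workers" -> 2))
//     .setFeaturesCol("features")
//     .setLabelCol("label")
//   val model: XGBoostRegressionModel = regressor.fit(trainDF)
//   model.transform(testDF).show()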
class XGBoostRegressionModel private[ml] (
override val uid: String,
private[spark] val _booster: Booster)
extends PredictionModel[Vector, XGBoostRegressionModel]
with XGBoostRegressorParams with MLWritable with Serializable {
import XGBoostRegressionModel._
// only called in copy()
def this(uid: String) = this(uid, null)
private var trainingSummary: Option[XGBoostTrainingSummary] = None
/**
* Returns the summary (e.g. train/test objective history) of the model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostRegressionModel")
}
private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}
override def predict(features: Vector): Double = {
throw new Exception("XGBoost-Spark does not support online prediction")
}
private def transformInternal(dataset: Dataset[_]): DataFrame = {
val schema = StructType(dataset.schema.fields ++
Seq(StructField(name = _originalPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName
val rdd = dataset.rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
$(featuresCol))).toList.iterator
import DataUtils._
val cacheInfo = {
if ($(useExternalMemory)) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
try {
val originalPredictionItr = {
bBooster.value.predict(dm).map(Row(_)).iterator
}
Rabit.shutdown()
rowItr1.zip(originalPredictionItr).map {
case (originals: Row, originalPrediction: Row) =>
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
}
} finally {
dm.delete()
}
} else {
Iterator[Row]()
}
}
bBooster.unpersist(blocking = false)
dataset.sparkSession.createDataFrame(rdd, schema)
}
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = transformInternal(dataset)
var numColsOutput = 0
val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
originalPrediction(0).toDouble
}
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_originalPredictionCol)))
numColsOutput += 1
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData.toDF.drop(col(_originalPredictionCol))
}
override def copy(extra: ParamMap): XGBoostRegressionModel = {
val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra)
newModel.setSummary(summary).setParent(parent)
}
override def write: MLWriter =
new XGBoostRegressionModel.XGBoostRegressionModelWriter(this)
}
object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
private val _originalPredictionCol = "_originalPrediction"
override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
override def load(path: String): XGBoostRegressionModel = super.load(path)
private[XGBoostRegressionModel]
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
outputStream.close()
}
}
private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] {
/** Checked against metadata when loading model */
private val className = classOf[XGBoostRegressionModel].getName
override def load(path: String): XGBoostRegressionModel = {
implicit val sc = super.sparkSession.sparkContext
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostRegressionModel(metadata.uid, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
model
}
}
}
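// Usage sketch (illustrative): persistence through the standard Spark ML
// reader/writer defined above; the path is hypothetical.
//
//   model.write.overwrite().save("/tmp/xgbRegressionModel")
//   val sameModel = XGBoostRegressionModel.load("/tmp/xgbRegressionModel")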

View File

@ -20,40 +20,48 @@ import scala.collection.immutable.HashSet
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
trait BoosterParams extends Params {
private[spark] trait BoosterParams extends Params {
/**
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
*/
val boosterType = new Param[String](this, "booster",
final val booster = new Param[String](this, "booster",
s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
(value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))
final def getBooster: String = $(booster)
/**
* step size shrinkage used in update to prevent overfitting. After each boosting step, we
* can directly get the weights of new features and eta actually shrinks the feature weights
* to make the boosting process more conservative. [default=0.3] range: [0,1]
*/
val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevent" +
" overfitting. After each boosting step, we can directly get the weights of new features," +
" and eta actually shrinks the feature weights to make the boosting process more conservative.",
(value: Double) => value >= 0 && value <= 1)
final def getEta: Double = $(eta)
/**
* minimum loss reduction required to make a further partition on a leaf node of the tree.
* the larger, the more conservative the algorithm will be. [default=0] range: [0,
* Double.MaxValue]
*/
val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a further" +
" partition on a leaf node of the tree. the larger, the more conservative the algorithm" +
" will be.", (value: Double) => value >= 0)
final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
"further partition on a leaf node of the tree. the larger, the more conservative the " +
"algorithm will be.", (value: Double) => value >= 0)
final def getGamma: Double = $(gamma)
/**
* maximum depth of a tree; increasing this value will make the model more complex and more
* likely to overfit. [default=6] range: [1, Int.MaxValue]
*/
val maxDepth = new IntParam(this, "max_depth", "maximum depth of a tree, increase this value" +
" will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increasing this " +
"value will make the model more complex/more likely to overfit.", (value: Int) => value >= 1)
final def getMaxDepth: Int = $(maxDepth)
/**
* minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
@ -62,13 +70,15 @@ trait BoosterParams extends Params {
* to minimum number of instances needed to be in each node. The larger, the more conservative
* the algorithm will be. [default=1] range: [0, Double.MaxValue]
*/
val minChildWeight = new DoubleParam(this, "min_child_weight", "minimum sum of instance" +
final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
" weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
" the sum of instance weight less than min_child_weight, then the building process will" +
" give up further partitioning. In linear regression mode, this simply corresponds to minimum" +
" number of instances needed to be in each node. The larger, the more conservative" +
" the algorithm will be.", (value: Double) => value >= 0)
final def getMinChildWeight: Double = $(minChildWeight)
/**
* Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
* means there is no constraint. If it is set to a positive value, it can help making the update
@ -76,90 +86,113 @@ trait BoosterParams extends Params {
* regression when classes are extremely imbalanced. Setting it to a value of 1-10 might help control the
* update. [default=0] range: [0, Double.MaxValue]
*/
val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow each" +
" tree's weight" +
final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " +
"each tree's weight" +
" estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
" to a positive value, it can help making the update step more conservative. Usually this" +
" parameter is not needed, but it might help in logistic regression when class is extremely" +
" imbalanced. Set it to value of 1-10 might help control the update",
(value: Double) => value >= 0)
final def getMaxDeltaStep: Double = $(maxDeltaStep)
/**
* subsample ratio of the training instances. Setting it to 0.5 means that XGBoost randomly
* collects half of the data instances to grow trees, which prevents overfitting.
* [default=1] range:(0,1]
*/
val subSample = new DoubleParam(this, "subsample", "subsample ratio of the training instance." +
" Setting it to 0.5 means that XGBoost randomly collected half of the data instances to" +
" grow trees and this will prevent overfitting.", (value: Double) => value <= 1 && value > 0)
final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
"instance. Setting it to 0.5 means that XGBoost randomly collected half of the data " +
"instances to grow trees and this will prevent overfitting.",
(value: Double) => value <= 1 && value > 0)
final def getSubsample: Double = $(subsample)
/**
* subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
*/
val colSampleByTree = new DoubleParam(this, "colsample_bytree", "subsample ratio of columns" +
" when constructing each tree.", (value: Double) => value <= 1 && value > 0)
final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
"columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)
final def getColsampleBytree: Double = $(colsampleBytree)
/**
* subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
*/
val colSampleByLevel = new DoubleParam(this, "colsample_bylevel", "subsample ratio of columns" +
" for each split, in each level.", (value: Double) => value <= 1 && value > 0)
final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
"columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)
final def getColsampleBylevel: Double = $(colsampleBylevel)
/**
* L2 regularization term on weights; increasing this value will make the model more conservative.
* [default=1]
*/
val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, increase this" +
" value will make model more conservative.", (value: Double) => value >= 0)
final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, " +
"increase this value will make model more conservative.", (value: Double) => value >= 0)
final def getLambda: Double = $(lambda)
/**
* L1 regularization term on weights; increasing this value will make the model more conservative.
* [default=0]
*/
val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase this" +
" value will make model more conservative.", (value: Double) => value >= 0)
final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increasing " +
"this value will make the model more conservative.", (value: Double) => value >= 0)
final def getAlpha: Double = $(alpha)
/**
* The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
* [default='auto']
*/
val treeMethod = new Param[String](this, "tree_method",
final val treeMethod = new Param[String](this, "treeMethod",
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
final def getTreeMethod: String = $(treeMethod)
/**
* growth policy for fast histogram algorithm
*/
val growthPolicty = new Param[String](this, "grow_policy",
final val growPolicy = new Param[String](this, "growPolicy",
"growth policy for fast histogram algorithm",
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
final def getGrowPolicy: String = $(growPolicy)
/**
* maximum number of bins in histogram
*/
val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram",
final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
(value: Int) => value > 0)
final def getMaxBins: Int = $(maxBins)
/**
* This is only used for approximate greedy algorithm.
* This roughly translates into O(1 / sketch_eps) bins. Compared to directly selecting the
* number of bins, this comes with a theoretical guarantee on sketch accuracy.
* [default=0.03] range: (0, 1)
*/
val sketchEps = new DoubleParam(this, "sketch_eps",
final val sketchEps = new DoubleParam(this, "sketchEps",
"This is only used for approximate greedy algorithm. This roughly translated into" +
" O(1 / sketch_eps) number of bins. Compared to directly select number of bins, this comes" +
" with theoretical guarantee with sketch accuracy.",
(value: Double) => value < 1 && value > 0)
final def getSketchEps: Double = $(sketchEps)
/**
* Control the balance of positive and negative weights, useful for unbalanced classes. A typical
* value to consider: sum(negative cases) / sum(positive cases). [default=1]
*/
val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of positive" +
" and negative weights, useful for unbalanced classes. A typical value to consider:" +
final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " +
"positive and negative weights, useful for unbalanced classes. A typical value to consider:" +
" sum(negative cases) / sum(positive cases)")
final def getScalePosWeight: Double = $(scalePosWeight)
// Dart boosters
/**
@ -167,72 +200,59 @@ trait BoosterParams extends Params {
* Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
* "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
*/
val sampleType = new Param[String](this, "sample_type", "type of sampling algorithm, options:" +
" {'uniform', 'weighted'}",
final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
"options: {'uniform', 'weighted'}",
(value: String) => BoosterParams.supportedSampleType.contains(value))
final def getSampleType: String = $(sampleType)
/**
* Parameter of Dart booster.
* type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
*/
val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
" algorithm, options: {'tree', 'forest'}",
(value: String) => BoosterParams.supportedNormalizeType.contains(value))
final def getNormalizeType: String = $(normalizeType)
/**
* Parameter of Dart booster.
* dropout rate. [default=0.0] range: [0.0, 1.0]
*/
val rateDrop = new DoubleParam(this, "rate_drop", "dropout rate", (value: Double) =>
final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
value >= 0 && value <= 1)
final def getRateDrop: Double = $(rateDrop)
/**
* Parameter of Dart booster.
* probability of skip dropout. If a dropout is skipped, new trees are added in the same manner
* as gbtree. [default=0.0] range: [0.0, 1.0]
*/
val skipDrop = new DoubleParam(this, "skip_drop", "probability of skip dropout. If" +
final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skip dropout. If" +
" a dropout is skipped, new trees are added in the same manner as gbtree.",
(value: Double) => value >= 0 && value <= 1)
final def getSkipDrop: Double = $(skipDrop)
// linear booster
/**
* Parameter of linear booster
* L2 regularization term on bias, default 0 (no L1 reg on bias because it is not important)
*/
val lambdaBias = new DoubleParam(this, "lambda_bias", "L2 regularization term on bias, default" +
" 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
"default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
final def getLambdaBias: Double = $(lambdaBias)
setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0,
growthPolicty -> "depthwise", maxBins -> 16,
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
growPolicy -> "depthwise", maxBins -> 16,
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)
/**
* Explains all params of this instance. See `explainParam()`.
*/
override def explainParams(): String = {
// TODO: filter some parameters according to the booster type
val boosterTypeStr = $(boosterType)
val validParamList = {
if (boosterTypeStr == "gblinear") {
// gblinear
params.filter(param => param.name == "lambda" ||
param.name == "alpha" || param.name == "lambda_bias")
} else if (boosterTypeStr != "dart") {
// gbtree
params.filter(param => param.name != "sample_type" &&
param.name != "normalize_type" && param.name != "rate_drop" && param.name != "skip_drop")
} else {
// dart
params.filter(_.name != "lambda_bias")
}
}
explainParam(boosterType) + "\n" ++ validParamList.map(explainParam).mkString("\n")
}
}
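// Usage sketch (illustrative): the camelCase booster params above are set via the
// typed setters on the estimators, and each setter goes through the param's
// validator, so e.g. an eta outside [0, 1] is rejected when it is set.
//
//   val reg = new XGBoostRegressor()
//     .setBooster("gbtree")
//     .setEta(0.05)
//     .setMaxDepth(8)
//     .setSubsample(0.8)
//     .setColsampleBytree(0.8)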
private[spark] object BoosterParams {

View File

@ -16,84 +16,104 @@
package ml.dmlc.xgboost4j.scala.spark.params
import com.google.common.base.CaseFormat
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.param._
import scala.collection.mutable
trait GeneralParams extends Params {
private[spark] trait GeneralParams extends Params {
/**
* The number of rounds for boosting
*/
val round = new IntParam(this, "num_round", "The number of rounds for boosting",
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
ParamValidators.gtEq(1))
final def getNumRound: Int = $(numRound)
/**
* number of workers used to train xgboost model. default: 1
*/
val nWorkers = new IntParam(this, "nworkers", "number of workers used to run xgboost",
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
ParamValidators.gtEq(1))
final def getNumWorkers: Int = $(numWorkers)
/**
* number of threads used per worker. default: 1
*/
val numThreadPerTask = new IntParam(this, "nthread", "number of threads used by per worker",
final val nthread = new IntParam(this, "nthread", "number of threads used per worker",
ParamValidators.gtEq(1))
final def getNthread: Int = $(nthread)
/**
* whether to use external memory as cache. default: false
*/
val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external" +
"memory as cache")
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
"whether to use external memory as cache")
final def getUseExternalMemory: Boolean = $(useExternalMemory)
/**
* 0 means printing running messages, 1 means silent mode. default: 0
*/
val silent = new IntParam(this, "silent",
final val silent = new IntParam(this, "silent",
"0 means printing running messages, 1 means silent mode.",
(value: Int) => value >= 0 && value <= 1)
final def getSilent: Int = $(silent)
/**
* customized objective function provided by user. default: null
*/
val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
"provided by user")
/**
* customized evaluation function provided by user. default: null
*/
val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
"provided by user")
final val customEval = new CustomEvalParam(this, "customEval",
"customized evaluation function provided by user")
/**
* the value treated as missing. default: Float.NaN
*/
val missing = new FloatParam(this, "missing", "the value treated as missing")
final val missing = new FloatParam(this, "missing", "the value treated as missing")
final def getMissing: Float = $(missing)
/**
* the maximum time to wait for the job requesting new workers. default: 30 minutes
*/
val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
" value is set smaller than or equal to 0.")
final val timeoutRequestWorkers = new LongParam(this, "timeoutRequestWorkers", "the maximum " +
"time to request new Workers if numCores are insufficient. The timeout will be disabled " +
"if this value is set smaller than or equal to 0.")
final def getTimeoutRequestWorkers: Long = $(timeoutRequestWorkers)
/**
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
*/
val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
"save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load " +
"the checkpoint with highest version as the starting point for training. If " +
final val checkpointPath = new Param[String](this, "checkpointPath", "the hdfs folder to load " +
"and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will " +
"load the checkpoint with highest version as the starting point for training. If " +
"checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
final def getCheckpointPath: String = $(checkpointPath)
/**
* Param for set checkpoint interval (&gt;= 1) or disable checkpoint (-1). E.g. 10 means that
* the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must
* also be set if the checkpoint interval is greater than 0.
*/
val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
"interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
"checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
" interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval",
"set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained " +
"model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be " +
"set if the checkpoint interval is greater than 0.",
(interval: Int) => interval == -1 || interval >= 1)
final def getCheckpointInterval: Int = $(checkpointInterval)
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
@ -122,15 +142,87 @@ trait GeneralParams extends Params {
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
* Ignored if the tracker implementation is "python".
*/
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
/** Random seed for the C++ part of XGBoost and train/test splitting. */
val seed = new LongParam(this, "seed", "random seed")
final val seed = new LongParam(this, "seed", "random seed")
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
final def getSeed: Long = $(seed)
setDefault(numRound -> 1, numWorkers -> 1, nthread -> 1,
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", checkpointInterval -> -1
)
}
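// Usage sketch (illustrative): wiring the general params, including checkpointing;
// the HDFS folder is hypothetical. With the settings below a checkpoint booster is
// written every 10 rounds, so a failed 500-round job can resume from the latest
// checkpoint instead of starting over.
//
//   val reg = new XGBoostRegressor()
//     .setNumRound(500)
//     .setNumWorkers(4)
//     .setCheckpointPath("hdfs:///tmp/xgboost/checkpoints")
//     .setCheckpointInterval(10)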
trait HasBaseMarginCol extends Params {
/**
* Param for initial prediction (aka base margin) column name.
* @group param
*/
final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
"Initial prediction (aka base margin) column name.")
/** @group getParam */
final def getBaseMarginCol: String = $(baseMarginCol)
}
trait HasGroupCol extends Params {
/**
* Param for group column name.
* @group param
*/
final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
/** @group getParam */
final def getGroupCol: String = $(groupCol)
}
trait HasNumClass extends Params {
/**
* number of classes
*/
final val numClass = new IntParam(this, "numClass", "number of classes")
/** @group getParam */
final def getNumClass: Int = $(numClass)
}
private[spark] trait ParamMapFuncs extends Params {
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
params.find(_.name == name) match {
case None =>
case Some(_: DoubleParam) =>
set(name, paramValue.toString.toDouble)
case Some(_: BooleanParam) =>
set(name, paramValue.toString.toBoolean)
case Some(_: IntParam) =>
set(name, paramValue.toString.toInt)
case Some(_: FloatParam) =>
set(name, paramValue.toString.toFloat)
case Some(_: Param[_]) =>
set(name, paramValue)
}
}
}
def MLlib2XGBoostParams: Map[String, Any] = {
val xgboostParams = new mutable.HashMap[String, Any]()
for (param <- params) {
if (isDefined(param)) {
val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name)
xgboostParams += name -> $(param)
}
}
xgboostParams.toMap
}
}
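// Usage sketch (illustrative): the two helpers above translate between naming
// styles with Guava's CaseFormat. An XGBoost-style snake_case map passed to the
// estimator lands on the camelCase MLlib params, and MLlib2XGBoostParams turns
// the defined params back into snake_case for the native trainer.
//
//   val reg = new XGBoostRegressor(Map("max_depth" -> 6, "num_round" -> 50))
//   assert(reg.getMaxDepth == 6 && reg.getNumRound == 50)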

View File

@ -20,76 +20,70 @@ import scala.collection.immutable.HashSet
import org.apache.spark.ml.param._
trait LearningTaskParams extends Params {
/**
* number of tasks to learn
*/
val numClasses = new IntParam(this, "num_class", "number of classes")
private[spark] trait LearningTaskParams extends Params {
/**
* Specify the learning task and the corresponding learning objective.
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
* multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:linear
*/
val objective = new Param[String](this, "objective", "objective function used for training," +
s" options: {${LearningTaskParams.supportedObjective.mkString(",")}",
final val objective = new Param[String](this, "objective", "objective function used for " +
s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
(value: String) => LearningTaskParams.supportedObjective.contains(value))
final def getObjective: String = $(objective)
/**
* the initial prediction score of all instances, global bias. default=0.5
*/
val baseScore = new DoubleParam(this, "base_score", "the initial prediction score of all" +
final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" +
" instances, global bias")
final def getBaseScore: Double = $(baseScore)
/**
* evaluation metrics for validation data, a default metric will be assigned according to
* objective (rmse for regression, and error for classification, mean average precision for
* ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map,
* gamma-deviance
*/
val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
" data, a default metric will be assigned according to objective (rmse for regression, and" +
" error for classification, mean average precision for ranking), options: " +
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
"validation data, a default metric will be assigned according to objective " +
"(rmse for regression, and error for classification, mean average precision for ranking), " +
s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
final def getEvalMetric: String = $(evalMetric)
/**
* group data specifying the size of each group for the ranking task. To correspond to the
* partitioning of the training data, it is nested.
*/
val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
" for ranking task. To correspond to partition of training data, it is nested.")
/**
* Initial prediction (aka base margin) column name.
*/
val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
/**
* Instance weights column name.
*/
val weightCol = new Param[String](this, "weightCol", "weight column name")
final val groupData = new GroupDataParam(this, "groupData", "group data specifying each " +
"group's size for the ranking task. To correspond to the partitioning of training data, " +
"it is nested.")
/**
* Fraction of training points to use for testing.
*/
val trainTestRatio = new DoubleParam(this, "trainTestRatio",
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
"fraction of training points to use for testing",
ParamValidators.inRange(0, 1))
final def getTrainTestRatio: Double = $(trainTestRatio)
/**
* If non-zero, the training will be stopped after a specified number
* of consecutive increases in any evaluation metric.
*/
val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
"number of rounds of decreasing eval metric to tolerate before " +
"stopping the training",
(value: Int) => value == 0 || value > 1)
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0,
numEarlyStoppingRounds -> 0)
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
setDefault(objective -> "reg:linear", baseScore -> 0.5, groupData -> null,
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
}
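// Usage sketch (illustrative): a ranking configuration. With a rank:* objective and
// no explicit evalMetric, training falls back to "map" (see setupDefaultEvalMetric
// in XGBoostRegressor); "group" is an assumed column holding each instance's
// query/group id.
//
//   val ranker = new XGBoostRegressor()
//     .setObjective("rank:pairwise")
//     .setGroupCol("group")
//     .setNumRound(100)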
private[spark] object LearningTaskParams {

View File

@ -1,75 +0,0 @@
0 1:985.574005058 2:320.223538037 3:0.621236086198
0 1:1010.52917943 2:635.535543082 3:2.14984030531
0 1:1012.91900422 2:132.387300057 3:0.488761066665
0 1:990.829194034 2:135.102081162 3:0.747701610673
0 1:1007.05103629 2:154.289183562 3:0.464118249201
0 1:994.9573036 2:317.483732878 3:0.0313685555674
0 1:987.8071541 2:731.349178363 3:0.244616944245
1 1:10.0349544469 2:2.29750906143 3:36.4949974282
0 1:9.92953881383 2:5.39134047297 3:120.041297548
0 1:10.0909866713 2:9.06191026312 3:138.807825798
1 1:10.2090970614 2:0.0784495944448 3:58.207703565
0 1:9.85695905893 2:9.99500727713 3:56.8610243778
1 1:10.0805758547 2:0.0410805760559 3:222.102302076
0 1:10.1209914486 2:9.9729127088 3:171.888238763
0 1:10.0331939798 2:0.853339303793 3:311.181328375
0 1:9.93901762951 2:2.72757449146 3:78.4859514413
0 1:10.0752365346 2:9.18695328235 3:49.8520256553
1 1:10.0456548902 2:0.270936043122 3:123.462958597
0 1:10.0568923673 2:0.82997113263 3:44.9391426001
0 1:9.8214143472 2:0.277538931578 3:15.4217659578
0 1:9.95258604431 2:8.69564346094 3:255.513470671
0 1:9.91934976357 2:7.72809741413 3:82.171591817
0 1:10.043239582 2:8.64168255553 3:38.9657919329
1 1:10.0236147929 2:0.0496662263659 3:4.40889812286
1 1:1001.85585324 2:3.75646886071 3:0.0179224994842
0 1:1014.25578571 2:0.285765311201 3:0.510329864983
1 1:1002.81422786 2:9.77676280375 3:0.433705951912
1 1:998.072711553 2:2.82100686538 3:0.889829076909
0 1:1003.77395036 2:2.55916592114 3:0.0359402151496
1 1:10.0807877782 2:4.98513959013 3:47.5266363559
0 1:10.0015013081 2:9.94302478763 3:78.3697486277
1 1:10.0441936789 2:0.305091816635 3:56.8213984987
0 1:9.94257106618 2:7.23909568913 3:442.463339039
1 1:9.86479307916 2:6.41701315844 3:55.1365304834
0 1:10.0428628516 2:9.98466447697 3:0.391632812588
0 1:9.94445884566 2:9.99970945878 3:260.438436534
1 1:9.84641392823 2:225.78051312 3:1.00525978847
1 1:9.86907690608 2:26.8971083147 3:0.577959255991
0 1:10.0177314626 2:0.110585342313 3:2.30545043031
0 1:10.0688190907 2:412.023866234 3:1.22421542264
0 1:10.1251769646 2:13.8212202925 3:0.129171734504
0 1:10.0840758802 2:407.359097187 3:0.477000870705
0 1:10.1007458705 2:987.183625145 3:0.149385677415
0 1:9.86472656059 2:169.559640615 3:0.147221652519
0 1:9.94207419238 2:507.290053755 3:0.41996207214
0 1:9.9671005502 2:1.62610457716 3:0.408173666788
0 1:1010.57126596 2:9.06673707562 3:0.672092284372
0 1:1001.6718262 2:9.53203990055 3:4.7364050044
0 1:995.777341384 2:4.43847316256 3:2.07229073634
0 1:1002.95701386 2:5.51711016665 3:1.24294450546
0 1:1016.0988238 2:0.626468941906 3:0.105627919134
0 1:1013.67571419 2:0.042315529666 3:0.717619310322
1 1:994.747747892 2:6.01989364024 3:0.772910130015
1 1:991.654593872 2:7.35575736952 3:1.19822091548
0 1:1008.47101732 2:8.28240754909 3:0.229582481359
0 1:1000.81975227 2:1.52448354056 3:0.096441660362
0 1:10.0900922344 2:322.656649307 3:57.8149073088
1 1:10.0868337371 2:2.88652339174 3:54.8865514572
0 1:10.0988984137 2:979.483832657 3:52.6809830901
0 1:9.97678959238 2:665.770979738 3:481.069628909
0 1:9.78554312773 2:257.309358658 3:47.7324475232
0 1:10.0985967566 2:935.896512941 3:138.937052808
0 1:10.0522252319 2:876.376299607 3:6.00373510669
1 1:9.88065229501 2:9.99979825653 3:0.0674603696149
0 1:10.0483244098 2:0.0653852316381 3:0.130679349938
1 1:9.99685215607 2:1.76602542774 3:0.2551321159
0 1:9.99750159428 2:1.01591534436 3:0.145445506504
1 1:9.97380908941 2:0.940048645571 3:0.411805696316
0 1:9.99977678382 2:6.91329929641 3:5.57858201258
0 1:978.876096381 2:933.775364741 3:0.579170824236
0 1:998.381016406 2:220.940470582 3:2.01491778565
0 1:987.917644594 2:8.74667873567 3:0.364006099758
0 1:1000.20994892 2:25.2945450565 3:3.5684398964
0 1:1014.57141264 2:675.593540733 3:0.164174055535
0 1:998.867283535 2:765.452750642 3:0.818425293238

View File

@ -1,10 +0,0 @@
7
7
10
5
7
10
10
7
6
6

View File

@ -1,74 +0,0 @@
0 1:10.2143092481 2:273.576539531 3:137.111774354
0 1:10.0366658918 2:842.469052609 3:2.32134375927
0 1:10.1281202091 2:395.654057342 3:35.4184893063
0 1:10.1443721289 2:960.058461049 3:272.887070637
0 1:10.1353234784 2:535.51304462 3:2.15393842032
1 1:10.0451640374 2:216.733858424 3:55.6533298016
1 1:9.94254592171 2:44.5985537358 3:304.614176871
0 1:10.1319257181 2:613.545504487 3:5.42391587912
0 1:1020.63622468 2:997.476744201 3:0.509425590461
0 1:986.304585519 2:822.669937965 3:0.605133561808
1 1:1012.66863221 2:26.7185759069 3:0.0875458784828
0 1:995.387656321 2:81.8540176995 3:0.691999430068
0 1:1020.6587198 2:848.826964547 3:0.540159430526
1 1:1003.81573853 2:379.84350931 3:0.0083682925194
0 1:1021.60921516 2:641.376951467 3:1.12339054807
0 1:1000.17585041 2:122.107138713 3:1.09906375372
1 1:987.64802348 2:5.98448541152 3:0.124241987204
1 1:9.94610136583 2:346.114985897 3:0.387708236565
0 1:9.96812192337 2:313.278109696 3:0.00863026595671
0 1:10.0181739194 2:36.7378924562 3:2.92179879835
0 1:9.89000102695 2:164.273723971 3:0.685222591968
0 1:10.1555212436 2:320.451459462 3:2.01341536261
0 1:10.0085727613 2:999.767117646 3:0.462294934168
1 1:9.93099658724 2:5.17478203909 3:0.213855205032
0 1:10.0629454957 2:663.088181857 3:0.049022351462
0 1:10.1109732417 2:734.904569784 3:1.6998450094
0 1:1006.6015266 2:505.023453703 3:1.90870566777
0 1:991.865769489 2:245.437343115 3:0.475109744256
0 1:998.682734072 2:950.041057232 3:1.9256314201
0 1:1005.02207209 2:2.9619314197 3:0.0517146822357
0 1:1002.54526214 2:860.562681899 3:0.915687092848
0 1:1000.38847359 2:808.416525088 3:0.209690673808
1 1:992.557818382 2:373.889409453 3:0.107571728577
0 1:1002.07722137 2:997.329626371 3:1.06504260496
0 1:1000.40504333 2:949.832139189 3:0.539159980327
0 1:10.1460179902 2:8.86082969819 3:135.953842715
1 1:9.98529296553 2:2.87366448495 3:1.74249892194
0 1:9.88942676744 2:9.4031821056 3:149.473066381
1 1:10.0192953341 2:1.99685737576 3:1.79502473397
0 1:10.0110654379 2:8.13112593726 3:87.7765628103
0 1:997.148677047 2:733.936190093 3:1.49298494242
0 1:1008.70465919 2:957.121652078 3:0.217414013634
1 1:997.356154278 2:541.599587807 3:0.100855972216
0 1:999.615897283 2:943.700501824 3:0.862874175879
1 1:997.36859077 2:0.200859940848 3:0.13601892182
0 1:10.0423255624 2:1.73855202168 3:0.956695338485
1 1:9.88440755486 2:9.9994600678 3:0.305080529665
0 1:10.0891026412 2:3.28031719474 3:0.364450973697
0 1:9.90078644258 2:8.77839663617 3:0.456660574479
1 1:9.79380029711 2:8.77220326156 3:0.527292005175
0 1:9.93613887011 2:9.76270841268 3:1.40865693823
0 1:10.0009239007 2:7.29056178263 3:0.498015866607
0 1:9.96603319905 2:5.12498000925 3:0.517492532783
0 1:10.0923827222 2:2.76652583955 3:1.56571226159
1 1:10.0983782035 2:587.788120694 3:0.031756483687
1 1:9.91397225464 2:994.527496819 3:3.72092164978
0 1:10.1057472738 2:2.92894440088 3:0.683506438532
0 1:10.1014053354 2:959.082038017 3:1.07039624129
0 1:10.1433253044 2:322.515119317 3:0.51408278993
1 1:9.82832510699 2:637.104433908 3:0.250272776427
0 1:1000.49729075 2:2.75336888111 3:0.576634423274
1 1:984.90338088 2:0.0295435794035 3:1.26273339929
0 1:1001.53811442 2:4.64164410861 3:0.0293389959504
1 1:995.875898395 2:5.08223403205 3:0.382330566779
0 1:996.405937252 2:6.26395190757 3:0.453645816611
0 1:10.0165140779 2:340.126072514 3:0.220794603312
0 1:9.93482824816 2:951.672000448 3:0.124406293612
0 1:10.1700278554 2:0.0140985961008 3:0.252452256311
0 1:9.99825079542 2:950.382643896 3:0.875382402062
0 1:9.87316410028 2:686.788257829 3:0.215886999825
0 1:10.2893240654 2:89.3947931451 3:0.569578232133
0 1:9.98689192703 2:0.430107535413 3:2.99869831728
0 1:10.1365175107 2:972.279245093 3:0.0865099386744
0 1:9.90744703306 2:50.810461183 3:3.00863325197

View File

@ -1,10 +0,0 @@
8
9
9
9
5
5
9
6
5
9

View File

@ -1,10 +0,0 @@
7
5
9
6
6
8
7
6
5
7

View File

@ -0,0 +1,66 @@
0,10.0229017899,7.30178495562,0.118115020017,1
0,9.93639621859,9.93102159291,0.0435030004396,1
0,10.1301737265,0.00411765220572,2.4165878053,1
1,9.87828587087,0.608588414992,0.111262590883,1
0,10.1373430048,0.47764012225,0.991553052194,1
0,10.0523814718,4.72152505167,0.672978832666,1
0,10.0449715742,8.40373928536,0.384457573667,1
1,996.398498791,941.976309154,0.230269231292,2
0,1005.11269468,900.093680877,0.265031528873,2
0,997.160349441,891.331101688,2.19362017313,2
0,993.754139031,44.8000165317,1.03868009875,2
1,994.831299184,241.959208453,0.667631827024,2
0,995.948333283,7.94326917112,0.750490877118,3
0,989.733981273,7.52077625436,0.0126335967282,3
0,1003.54086516,6.48177510564,1.19441696788,3
0,996.56177804,9.71959812613,1.33082465111,3
0,1005.61382467,0.234339369309,1.17987797356,3
1,980.215758708,6.85554542926,2.63965085259,3
1,987.776408872,2.23354609991,0.841885278028,3
0,1006.54260396,8.12142049834,2.26639471174,3
0,1009.87927639,6.40028519044,0.775155669615,3
0,9.95006244393,928.76896718,234.948458244,4
1,10.0749152258,255.294574476,62.9728604166,4
1,10.1916541988,312.682867085,92.299413677,4
0,9.95646724484,742.263188416,53.3310473654,4
0,9.86211293222,996.237023866,2.00760301168,4
1,9.91801019468,303.971783709,50.3147230679,4
0,996.983996934,9.52188222766,1.33588120981,5
0,995.704388126,9.49260524915,0.908498516541,5
0,987.86480767,0.0870786716821,0.108859297837,5
0,1000.99561307,2.85272694575,0.171134518956,5
0,1011.05508066,7.55336771768,1.04950084825,5
1,985.52199365,0.763305780608,1.7402424375,5
0,10.0430321467,813.185427181,4.97728254185,6
0,10.0812334228,258.297288417,0.127477670549,6
0,9.84210504292,887.205815261,0.991689193955,6
1,9.94625332613,0.298622762132,0.147881353231,6
0,9.97800659954,727.619819757,0.0718361141866,6
1,9.8037938472,957.385549617,0.0618862028941,6
0,10.0880634741,185.024638577,1.7028095095,6
0,9.98630799154,109.10631473,0.681117359751,6
0,9.91671416638,166.248076588,122.538291094,7
0,10.1206910464,88.1539468531,141.189859069,7
1,10.1767160518,1.02960996847,172.02256237,7
0,9.93025147233,391.196641942,58.040338247,7
0,9.84850936037,474.63346537,17.5627875397,7
1,9.8162731343,61.9199554213,30.6740972851,7
0,10.0403482984,987.50416929,73.0472906209,7
1,997.019228359,133.294717663,0.0572254083186,8
0,973.303999107,1.79080888849,0.100478717048,8
0,1008.28808825,342.282350685,0.409806485495,8
0,1014.55621524,0.680510407082,0.929530602495,8
1,1012.74370325,823.105266455,0.0894693730585,8
0,1003.63554038,727.334432075,0.58206275756,8
0,10.1560432436,740.35938307,11.6823378533,9
0,9.83949099701,512.828227154,138.206666681,9
1,10.1837395682,179.287126088,185.479062365,9
1,9.9761881495,12.1093388336,9.1264604171,9
1,9.77402180766,318.561317743,80.6005221355,9
0,1011.15705381,0.215825852155,1.34429667906,10
0,1005.60353229,727.202346126,1.47146041005,10
1,1013.93702961,58.7312725205,0.421041560754,10
0,1004.86813074,757.693204258,0.566055205344,10
0,999.996324692,813.12386828,0.864428279513,10
0,996.55255931,918.760056995,0.43365051974,10
1,1004.1394132,464.371823646,0.312492288321,10

View File

@ -0,0 +1,149 @@
0,985.574005058,320.223538037,0.621236086198,1
0,1010.52917943,635.535543082,2.14984030531,1
0,1012.91900422,132.387300057,0.488761066665,1
0,990.829194034,135.102081162,0.747701610673,1
0,1007.05103629,154.289183562,0.464118249201,1
0,994.9573036,317.483732878,0.0313685555674,1
0,987.8071541,731.349178363,0.244616944245,1
1,10.0349544469,2.29750906143,36.4949974282,2
0,9.92953881383,5.39134047297,120.041297548,2
0,10.0909866713,9.06191026312,138.807825798,2
1,10.2090970614,0.0784495944448,58.207703565,2
0,9.85695905893,9.99500727713,56.8610243778,2
1,10.0805758547,0.0410805760559,222.102302076,2
0,10.1209914486,9.9729127088,171.888238763,2
0,10.0331939798,0.853339303793,311.181328375,3
0,9.93901762951,2.72757449146,78.4859514413,3
0,10.0752365346,9.18695328235,49.8520256553,3
1,10.0456548902,0.270936043122,123.462958597,3
0,10.0568923673,0.82997113263,44.9391426001,3
0,9.8214143472,0.277538931578,15.4217659578,3
0,9.95258604431,8.69564346094,255.513470671,3
0,9.91934976357,7.72809741413,82.171591817,3
0,10.043239582,8.64168255553,38.9657919329,3
1,10.0236147929,0.0496662263659,4.40889812286,3
1,1001.85585324,3.75646886071,0.0179224994842,4
0,1014.25578571,0.285765311201,0.510329864983,4
1,1002.81422786,9.77676280375,0.433705951912,4
1,998.072711553,2.82100686538,0.889829076909,4
0,1003.77395036,2.55916592114,0.0359402151496,4
1,10.0807877782,4.98513959013,47.5266363559,5
0,10.0015013081,9.94302478763,78.3697486277,5
1,10.0441936789,0.305091816635,56.8213984987,5
0,9.94257106618,7.23909568913,442.463339039,5
1,9.86479307916,6.41701315844,55.1365304834,5
0,10.0428628516,9.98466447697,0.391632812588,5
0,9.94445884566,9.99970945878,260.438436534,5
1,9.84641392823,225.78051312,1.00525978847,6
1,9.86907690608,26.8971083147,0.577959255991,6
0,10.0177314626,0.110585342313,2.30545043031,6
0,10.0688190907,412.023866234,1.22421542264,6
0,10.1251769646,13.8212202925,0.129171734504,6
0,10.0840758802,407.359097187,0.477000870705,6
0,10.1007458705,987.183625145,0.149385677415,6
0,9.86472656059,169.559640615,0.147221652519,6
0,9.94207419238,507.290053755,0.41996207214,6
0,9.9671005502,1.62610457716,0.408173666788,6
0,1010.57126596,9.06673707562,0.672092284372,7
0,1001.6718262,9.53203990055,4.7364050044,7
0,995.777341384,4.43847316256,2.07229073634,7
0,1002.95701386,5.51711016665,1.24294450546,7
0,1016.0988238,0.626468941906,0.105627919134,7
0,1013.67571419,0.042315529666,0.717619310322,7
1,994.747747892,6.01989364024,0.772910130015,7
1,991.654593872,7.35575736952,1.19822091548,7
0,1008.47101732,8.28240754909,0.229582481359,7
0,1000.81975227,1.52448354056,0.096441660362,7
0,10.0900922344,322.656649307,57.8149073088,8
1,10.0868337371,2.88652339174,54.8865514572,8
0,10.0988984137,979.483832657,52.6809830901,8
0,9.97678959238,665.770979738,481.069628909,8
0,9.78554312773,257.309358658,47.7324475232,8
0,10.0985967566,935.896512941,138.937052808,8
0,10.0522252319,876.376299607,6.00373510669,8
1,9.88065229501,9.99979825653,0.0674603696149,9
0,10.0483244098,0.0653852316381,0.130679349938,9
1,9.99685215607,1.76602542774,0.2551321159,9
0,9.99750159428,1.01591534436,0.145445506504,9
1,9.97380908941,0.940048645571,0.411805696316,9
0,9.99977678382,6.91329929641,5.57858201258,9
0,978.876096381,933.775364741,0.579170824236,10
0,998.381016406,220.940470582,2.01491778565,10
0,987.917644594,8.74667873567,0.364006099758,10
0,1000.20994892,25.2945450565,3.5684398964,10
0,1014.57141264,675.593540733,0.164174055535,10
0,998.867283535,765.452750642,0.818425293238,10
0,10.2143092481,273.576539531,137.111774354,11
0,10.0366658918,842.469052609,2.32134375927,11
0,10.1281202091,395.654057342,35.4184893063,11
0,10.1443721289,960.058461049,272.887070637,11
0,10.1353234784,535.51304462,2.15393842032,11
1,10.0451640374,216.733858424,55.6533298016,11
1,9.94254592171,44.5985537358,304.614176871,11
0,10.1319257181,613.545504487,5.42391587912,11
0,1020.63622468,997.476744201,0.509425590461,12
0,986.304585519,822.669937965,0.605133561808,12
1,1012.66863221,26.7185759069,0.0875458784828,12
0,995.387656321,81.8540176995,0.691999430068,12
0,1020.6587198,848.826964547,0.540159430526,12
1,1003.81573853,379.84350931,0.0083682925194,12
0,1021.60921516,641.376951467,1.12339054807,12
0,1000.17585041,122.107138713,1.09906375372,12
1,987.64802348,5.98448541152,0.124241987204,12
1,9.94610136583,346.114985897,0.387708236565,13
0,9.96812192337,313.278109696,0.00863026595671,13
0,10.0181739194,36.7378924562,2.92179879835,13
0,9.89000102695,164.273723971,0.685222591968,13
0,10.1555212436,320.451459462,2.01341536261,13
0,10.0085727613,999.767117646,0.462294934168,13
1,9.93099658724,5.17478203909,0.213855205032,13
0,10.0629454957,663.088181857,0.049022351462,13
0,10.1109732417,734.904569784,1.6998450094,13
0,1006.6015266,505.023453703,1.90870566777,14
0,991.865769489,245.437343115,0.475109744256,14
0,998.682734072,950.041057232,1.9256314201,14
0,1005.02207209,2.9619314197,0.0517146822357,14
0,1002.54526214,860.562681899,0.915687092848,14
0,1000.38847359,808.416525088,0.209690673808,14
1,992.557818382,373.889409453,0.107571728577,14
0,1002.07722137,997.329626371,1.06504260496,14
0,1000.40504333,949.832139189,0.539159980327,14
0,10.1460179902,8.86082969819,135.953842715,15
1,9.98529296553,2.87366448495,1.74249892194,15
0,9.88942676744,9.4031821056,149.473066381,15
1,10.0192953341,1.99685737576,1.79502473397,15
0,10.0110654379,8.13112593726,87.7765628103,15
0,997.148677047,733.936190093,1.49298494242,16
0,1008.70465919,957.121652078,0.217414013634,16
1,997.356154278,541.599587807,0.100855972216,16
0,999.615897283,943.700501824,0.862874175879,16
1,997.36859077,0.200859940848,0.13601892182,16
0,10.0423255624,1.73855202168,0.956695338485,17
1,9.88440755486,9.9994600678,0.305080529665,17
0,10.0891026412,3.28031719474,0.364450973697,17
0,9.90078644258,8.77839663617,0.456660574479,17
1,9.79380029711,8.77220326156,0.527292005175,17
0,9.93613887011,9.76270841268,1.40865693823,17
0,10.0009239007,7.29056178263,0.498015866607,17
0,9.96603319905,5.12498000925,0.517492532783,17
0,10.0923827222,2.76652583955,1.56571226159,17
1,10.0983782035,587.788120694,0.031756483687,18
1,9.91397225464,994.527496819,3.72092164978,18
0,10.1057472738,2.92894440088,0.683506438532,18
0,10.1014053354,959.082038017,1.07039624129,18
0,10.1433253044,322.515119317,0.51408278993,18
1,9.82832510699,637.104433908,0.250272776427,18
0,1000.49729075,2.75336888111,0.576634423274,19
1,984.90338088,0.0295435794035,1.26273339929,19
0,1001.53811442,4.64164410861,0.0293389959504,19
1,995.875898395,5.08223403205,0.382330566779,19
0,996.405937252,6.26395190757,0.453645816611,19
0,10.0165140779,340.126072514,0.220794603312,20
0,9.93482824816,951.672000448,0.124406293612,20
0,10.1700278554,0.0140985961008,0.252452256311,20
0,9.99825079542,950.382643896,0.875382402062,20
0,9.87316410028,686.788257829,0.215886999825,20
0,10.2893240654,89.3947931451,0.569578232133,20
0,9.98689192703,0.430107535413,2.99869831728,20
0,10.1365175107,972.279245093,0.0865099386744,20
0,9.90744703306,50.810461183,3.00863325197,20
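
Each row of the new rank.train.csv fixture above has the shape label,feature1,feature2,feature3,group: the first column is the relevance label and the trailing integer is the query-group id consumed by the new ranking support (see getLabeledPointsWithGroup in TrainTestData below). A minimal standalone sketch of that parsing; the helper name parseRankRow is illustrative, not part of the patch:

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

// Illustrative only: mirrors TrainTestData.getLabeledPointsWithGroup below.
def parseRankRow(line: String): XGBLabeledPoint = {
  val cols = line.split(",")
  val label = cols.head.toFloat                               // first column: relevance label
  val group = cols.last.toInt                                 // last column: query-group id
  val values = cols.slice(1, cols.length - 1).map(_.toFloat)  // middle columns: features
  XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)  // unit weight, no base margin
}

// parseRankRow("0,985.574005058,320.223538037,0.621236086198,1")
// yields label 0.0f, features [985.574f, 320.224f, 0.621f], group 1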

View File

@@ -21,37 +21,27 @@ import java.nio.file.Files
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkConf, SparkContext}
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
var sc: SparkContext = _
override def beforeAll(): Unit = {
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("XGBoostSuite")
sc = new SparkContext(conf)
}
class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll {
private lazy val (model4, model8) = {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, nWorkers = sc.defaultParallelism),
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, nWorkers = sc.defaultParallelism))
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateCheckpoint(model4)
manager.updateCheckpoint(model4._booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
manager.updateCheckpoint(model8)
manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
@@ -61,7 +51,7 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateCheckpoint(model8)
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
@@ -74,7 +64,8 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
manager.updateCheckpoint(model4)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
}
}
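
The hunk above also shows the shape of the migration this refactor asks of callers; a minimal before/after sketch, assuming a DataFrame named training with label and features columns:

// Old RDD-based entry point (removed in this refactor):
// val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, nWorkers = sc.defaultParallelism)

// New MLlib-style estimator: num_round and num_workers now travel in the param map.
val model = new XGBoostClassifier(paramMap ++ Map("num_round" -> 2, "num_workers" -> sc.defaultParallelism))
  .fit(training)
val booster = model._booster  // the underlying ml.dmlc.xgboost4j.scala.Booster used by CheckpointManager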

View File

@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
@transient private var currentSession: SparkSession = _
@@ -62,4 +64,30 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
file.delete()
}
}
protected def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features")
}
protected def buildDataFrameWithGroup(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features", "group")
}
}
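
For reference, a hedged usage sketch of the two helpers above (Classification.train and Ranking.train come from TrainTestData later in this diff):

val df = buildDataFrame(Classification.train)        // columns: id, label, features
val rankDF = buildDataFrameWithGroup(Ranking.train)  // columns: id, label, features, group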

View File

@@ -0,0 +1,167 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import java.util.Arrays
import ml.dmlc.xgboost4j.scala.DMatrix
import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.network.util.JavaUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
private var tempDir: File = _
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = new File(System.getProperty("java.io.tmpdir"), this.getClass.getName)
if (tempDir.exists) {
tempDir.delete
}
tempDir.mkdirs
}
override def afterAll(): Unit = {
JavaUtils.deleteRecursively(tempDir)
super.afterAll()
}
private def delete(f: File) {
if (f.exists) {
if (f.isDirectory) {
for (c <- f.listFiles) {
delete(c)
}
}
if (!f.delete) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
}
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()
val trainingDF = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
val xgbc = new XGBoostClassifier(paramMap)
val xgbcPath = new File(tempDir, "xgbc").getPath
xgbc.write.overwrite().save(xgbcPath)
val xgbc2 = XGBoostClassifier.load(xgbcPath)
val paramMap2 = xgbc2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
val model = xgbc.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbcModelPath = new File(tempDir, "xgbcModel").getPath
model.write.overwrite.save(xgbcModelPath)
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
val eval = new EvalError()
val trainingDF = buildDataFrame(Regression.train)
val testDM = new DMatrix(Regression.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "num_round" -> "10", "num_workers" -> numWorkers)
val xgbr = new XGBoostRegressor(paramMap)
val xgbrPath = new File(tempDir, "xgbr").getPath
xgbr.write.overwrite().save(xgbrPath)
val xgbr2 = XGBoostRegressor.load(xgbrPath)
val paramMap2 = xgbr2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
val model = xgbr.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbrModelPath = new File(tempDir, "xgbrModel").getPath
model.write.overwrite.save(xgbrModelPath)
val model2 = XGBoostRegressionModel.load(xgbrModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getPredictionCol === model2.getPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val xgb = new XGBoostClassifier(paramMap)
// Construct MLlib pipeline, save and load
val pipeline = new Pipeline().setStages(Array(assembler, xgb))
val pipePath = new File(tempDir, "pipeline").getPath
pipeline.write.overwrite().save(pipePath)
val pipeline2 = Pipeline.read.load(pipePath)
val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
val paramMap2 = xgb2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
// Model training, save and load
val pipeModel = pipeline.fit(df)
val pipeModelPath = new File(tempDir, "pipelineModel").getPath
pipeModel.write.overwrite.save(pipeModelPath)
val pipeModel2 = PipelineModel.load(pipeModelPath)
val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]
assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))
assert(xgbModel.getEta === xgbModel2.getEta)
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
}
}
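
Outside the test harness, the persistence round trip this suite verifies reduces to the following sketch (the path is illustrative, and trainingDF/testDF stand for any DataFrames with label and features columns):

val model = new XGBoostClassifier(paramMap).fit(trainingDF)
model.write.overwrite().save("/tmp/xgbcModel")
val restored = XGBoostClassificationModel.load("/tmp/xgbcModel")
restored.transform(testDF).show()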

View File

@@ -16,8 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import scala.io.Source
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
trait TrainTestData {
@@ -48,6 +48,17 @@ trait TrainTestData {
XGBLabeledPoint(label, null, values)
}.toList
}
protected def getLabeledPointsWithGroup(resource: String): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
val original = line.split(",")
val length = original.length
val label = original.head.toFloat
val group = original.last.toInt
val values = original.slice(1, length - 1).map(_.toFloat)
XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
}.toList
}
}
object Classification extends TrainTestData {
@@ -80,11 +91,8 @@ object Regression extends TrainTestData {
}
object Ranking extends TrainTestData {
val train0: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-0.txt.train", zeroBased = false)
val train1: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-1.txt.train", zeroBased = false)
val trainGroup0: Seq[Int] = getGroups("/rank-demo-0.txt.train.group")
val trainGroup1: Seq[Int] = getGroups("/rank-demo-1.txt.train.group")
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo.txt.test", zeroBased = false)
val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
private def getGroups(resource: String): Seq[Int] = {
getResourceLines(resource).map(_.toInt).toList

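With grouped training data now available through Ranking.train, a hedged sketch of training a pairwise ranker under the new API; setGroupCol is an assumption based on this refactor's ranking work and does not appear in this excerpt:

// Assumption: XGBoostRegressor exposes setGroupCol for the query-group column.
val rankerParams = Map("eta" -> "1", "objective" -> "rank:pairwise",
  "num_round" -> 5, "num_workers" -> 2)
val ranker = new XGBoostRegressor(rankerParams).setGroupCol("group")
val rankModel = ranker.fit(buildDataFrameWithGroup(Ranking.train))
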
View File

@@ -0,0 +1,207 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._
import org.scalatest.FunSuite
class XGBoostClassifierSuite extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
assert(testDF.count() === prediction2.size)
// the vector in the probability column has length 2 (one entry per class) to fit Spark's evaluators
for (i <- prediction1.indices) {
assert(prediction1(i).length === prediction2(i).values.length - 1)
for (j <- prediction1(i).indices) {
assert(prediction1(i)(j) === prediction2(i)(j + 1))
}
}
val prediction3 = model1.predict(testDM, outPutMargin = true)
val prediction4 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
assert(testDF.count() === prediction4.size)
for (i <- prediction3.indices) {
assert(prediction3(i).length === prediction4(i).values.length)
for (j <- prediction3(i).indices) {
assert(prediction3(i)(j) === prediction4(i)(j))
}
}
}
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> round,
"num_workers" -> numWorkers)
// Set params in XGBoost way
val model1 = new XGBoostClassifier(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostClassifier()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("binary:logistic")
.setNumRound(round)
.setNumWorkers(numWorkers)
.fit(trainingDF)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 === p2)
}
}
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val model = new XGBoostClassifier(paramMap).fit(trainingDF)
model.setRawPredictionCol("raw_prediction")
.setProbabilityCol("probability_prediction")
.setPredictionCol("final_prediction")
var predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("probability_prediction"))
assert(predictionDF.columns.contains("final_prediction"))
model.setRawPredictionCol("").setPredictionCol("final_prediction")
predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction"))
model.setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
// from xgboost params to spark params
val xgb = new XGBoostClassifier(xgbParamMap)
assert(xgb.getEta === 1.0)
assert(xgb.getObjective === "binary:logistic")
// from spark to xgboost params
val xgbCopy = xgb.copy(ParamMap.empty)
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
}
test("multi class classification") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.getEta == 0.1)
assert(model.getMaxDepth == 6)
assert(model.numClasses == 6)
}
test("use base margin") {
val training1 = buildDataFrame(Classification.train)
val training2 = training1.withColumn("margin", functions.rand())
val test = buildDataFrame(Classification.test)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "test_train_split" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val xgb = new XGBoostClassifier(paramMap)
val model1 = xgb.fit(training1)
val model2 = xgb.setBaseMarginCol("margin").fit(training2)
val prediction1 = model1.transform(test).select(model1.getProbabilityCol)
.collect().map(row => row.getAs[Vector](0))
val prediction2 = model2.transform(test).select(model2.getProbabilityCol)
.collect().map(row => row.getAs[Vector](0))
var count = 0
for ((r1, r2) <- prediction1.zip(prediction2)) {
if (!r1.equals(r2)) count = count + 1
}
assert(count != 0)
}
test("training summary") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
}

View File

@@ -17,36 +17,34 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.scalatest.FunSuite
class XGBoostConfigureSuite extends FunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
test("nthread configuration must be no larger than spark.task.cpus") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic",
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
intercept[IllegalArgumentException] {
XGBoost.trainWithRDD(sc.parallelize(List()), paramMap, 5, numWorkers)
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
}
}
test("kryoSerializer test") {
import DataUtils._
// TODO write an isolated test for Booster.
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator, null)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator, null)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val eval = new EvalError()
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
}

View File

@@ -1,265 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.scalatest.FunSuite
import org.scalatest.prop.TableDrivenPropertyChecks
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
private def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features")
}
test("test consistency and order preservation of dataframe-based model") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingItr = Classification.train.iterator
val testItr = Classification.test.iterator
val round = 5
val trainDMatrix = new DMatrix(trainingItr)
val testDMatrix = new DMatrix(testItr)
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
val predResultFromSeq = xgboostModel.predict(testDMatrix)
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = round, nWorkers = numWorkers)
val testDF = buildDataFrame(Classification.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
assert(testDF.count() === predResultsFromDF.size)
// the vector length in the probabilities column is 2 since we have to fit to the evaluator in
// Spark
for (i <- predResultFromSeq.indices) {
assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1)
for (j <- predResultFromSeq(i).indices) {
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1))
}
}
}
test("test transformLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers)
val testDF = buildDataFrame(Classification.test)
xgBoostModelWithDF.transformLeaf(testDF).show()
}
test("test schema of XGBoostRegressionModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
val trainingDF = buildDataFrame(Regression.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.setPredictionCol("final_prediction")
val testDF = buildDataFrame(Regression.test)
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("final_prediction"))
predictionDF.show()
}
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
"raw_prediction").setPredictionCol("final_prediction")
val testDF = buildDataFrame(Classification.test)
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
setPredictionCol("final_prediction")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
}
test("xgboost and spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
// from xgboost params to spark params
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
assert(xgbEstimator.get(xgbEstimator.eta).get === 1.0)
assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
// from spark to xgboost params
val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
}
test("eval_metric is configured correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
val sparkParamMap = ParamMap.empty
val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
}
ignore("fast histogram algorithm parameters are exposed correctly") {
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val testItr = Classification.test.iterator
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 10, nWorkers = math.min(2, numWorkers))
val error = new EvalError
val testSetDMatrix = new DMatrix(testItr)
assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
test("multi_class classification test") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val trainingDF = buildDataFrame(MultiClassification.train)
XGBoost.trainWithDataFrame(trainingDF.toDF(), paramMap, round = 5, nWorkers = numWorkers)
}
test("test DF use nested groupData") {
val trainingDF = buildDataFrame(Ranking.train0, 1)
.union(buildDataFrame(Ranking.train1, 1))
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = 2)
val testDF = buildDataFrame(Ranking.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
assert(testDF.count() === predResultsFromDF.size)
}
test("params of estimator and produced model are coordinated correctly") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val trainingDF = buildDataFrame(MultiClassification.train)
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
assert(model.get[Double](model.eta).get == 0.1)
assert(model.get[Int](model.maxDepth).get == 6)
assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
}
test("test use base margin") {
import DataUtils._
val trainingDf = buildDataFrame(Classification.train)
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
val testRDD = sc.parallelize(Classification.test.map(_.features))
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "baseMarginCol" -> "margin",
"testTrainSplit" -> 0.5)
def trainPredict(df: Dataset[_]): Array[Float] = {
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
.predict(testRDD)
.map { case Array(p) => p }
.collect()
}
val pred = trainPredict(trainingDf)
val predWithMargin = trainPredict(trainingDfWithMargin)
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
}
test("test use weight") {
import DataUtils._
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "weightCol" -> "weight")
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
.setPredictionCol("final_prediction")
.setExternalMemory(true)
val testRDD = sc.parallelize(Regression.test.map(_.features))
val predictions = model.predict(testRDD).collect().flatten
// The predictions rely heavily on the first training instance, and thus are very close.
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
}
test("training summary") {
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val trainingDf = buildDataFrame(Classification.train)
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
val trainingDf = buildDataFrame(Classification.train)
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
}
}

View File

@@ -18,19 +18,18 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.FunSuite
import scala.util.Random
class XGBoostGeneralSuite extends FunSuite with PerTest {
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
@@ -87,283 +86,153 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
}
test("training with external memory cache") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"use_external_memory" -> true)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("training with Scala-implemented Rabit tracker") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
ignore("test with fast histo depthwise") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "error")
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
// TODO: histogram algorithm seems to be very very sensitive to worker number
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
ignore("test with fast histo lossguide") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
"max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5,
"num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
ignore("test with fast histo lossguide with max bin") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
ignore("test with fast histo depthwidth with max depth") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
ignore("test with fast histo depthwidth with max depth and max bin") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[MLLabeledPoint] = {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
val numCols = 5
val labeledPoints = (0 until numRows).map { _ =>
val label = Random.nextDouble()
val data = (0 until numRows).map { x =>
val label = Random.nextInt(2)
val values = Array.tabulate[Double](numCols) { c =>
if (c == numCols - 1) -0.1 else Random.nextDouble()
if (c == numCols - 1) -0.1 else Random.nextDouble
}
MLLabeledPoint(label, Vectors.dense(values))
(label, Vectors.dense(values))
}
sc.parallelize(labeledPoints)
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
}
val trainingRDD = buildDenseRDD().repartition(4)
val testRDD = buildDenseRDD().repartition(4).map(_.features.asInstanceOf[DenseVector])
val denseDF = buildDenseDataFrame().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
useExternalMemory = true)
xgBoostModel.predict(testRDD, missingValue = -0.1f).collect()
}
test("test consistency of prediction functions with RDD") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSet = Classification.test
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val testCollection = testRDD.collect()
for (i <- testSet.indices) {
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
}
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1 = predRDD.collect()
assert(testRDD.count() === predResult1.length)
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
for (i <- predResult1.indices; j <- predResult1(i).indices) {
assert(predResult1(i)(j) === predResult2(i)(j))
}
}
test("test eval functions with RDD") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
// Nan Zhu: deprecate it for now
// xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
}
test("test prediction functionality with empty partition") {
import DataUtils._
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers)
}
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testRDD = buildEmptyRDD()
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
assert(xgBoostModel.predict(testRDD).collect().length === 0)
}
test("test use groupData") {
import DataUtils._
val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML)
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0)
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 2, nWorkers = 1)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)
val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
assert(avgMetric contains "ndcg")
// If the labels were lost ndcg comes back as 1.0
assert(avgMetric.split('=')(1).toFloat < 1F)
}
test("test use nested groupData") {
import DataUtils._
val trainingRDD0 = sc.parallelize(Ranking.train0, numSlices = 1)
val trainingRDD1 = sc.parallelize(Ranking.train1, numSlices = 1)
val trainingRDD = trainingRDD0.union(trainingRDD1).map(_.asML)
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
val model = new XGBoostClassifier(paramMap).fit(denseDF)
model.transform(denseDF).collect()
}
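
The two groupData tests above are tied to the removed RDD API, where per-partition group sizes were passed out of band as Seq[Seq[Int]]. After the refactor, each row instead carries its group id in a DataFrame column named by group_col, as exercised in XGBoostRegressorSuite further down. A minimal sketch, with an illustrative DataFrame name:

// New-style ranking: the "group" column marks query-group membership per row.
val ranker = new XGBoostRegressor(Map(
  "objective" -> "rank:pairwise",
  "num_round" -> 5,
  "num_workers" -> 2,
  "group_col" -> "group"))
val rankModel = ranker.fit(trainingDFWithGroup)
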
test("training with spark parallelism checks disabled") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
test("isClassificationTask correctly classifies supported objectives") {
import org.scalatest.prop.TableDrivenPropertyChecks._
val objectives = Table(
("isClassificationTask", "params"),
(true, Map("obj_type" -> "classification")),
(false, Map("obj_type" -> "regression")),
(false, Map("objective" -> "rank:ndcg")),
(false, Map("objective" -> "rank:pairwise")),
(false, Map("objective" -> "rank:map")),
(false, Map("objective" -> "count:poisson")),
(true, Map("objective" -> "binary:logistic")),
(true, Map("objective" -> "binary:logitraw")),
(true, Map("objective" -> "multi:softmax")),
(true, Map("objective" -> "multi:softprob")),
(false, Map("objective" -> "reg:linear")),
(false, Map("objective" -> "reg:logistic")),
(false, Map("objective" -> "reg:gamma")),
(false, Map("objective" -> "reg:tweedie")))
forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) =>
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
}
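
A minimal sketch of the dispatch rule the property table encodes, assuming an explicit obj_type wins and that classification is otherwise keyed off a binary: or multi: objective prefix; this is an inference from the table, not the actual implementation:

def isClassificationTask(params: Map[String, Any]): Boolean =
  params.get("obj_type") match {
    // An explicit obj_type overrides the objective string.
    case Some(objType) => objType == "classification"
    // Otherwise only binary:* and multi:* objectives count as classification.
    case None => params.get("objective").map(_.toString).exists { obj =>
      obj.startsWith("binary:") || obj.startsWith("multi:")
    }
  }
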
test("training with checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM)
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
nWorkers = numWorkers)
assert(error(tmpModel) > error(prevModel))
assert(error(prevModel) > error(nextModel))
assert(error(nextModel) < 0.1)
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
}
}
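
One note on the checkpoint test: the booster's internal version counter advances twice per boosting round, which would explain why a 5-round run with checkpoint_interval -> 2 leaves "8.model" as its single latest snapshot. A hedged sketch of the recovery path implied by checkpoint_path, reusing the names from the test above; the resume behaviour is inferred from the checkpoint parameters, not from documented API:

// Re-running fit() with the same checkpoint_path is expected to pick up the
// latest snapshot rather than restart from round 0 (sketch of the recovery flow).
val resumed = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
// The snapshot itself can also be inspected directly, as the test does:
val snapshot: Booster = SXGBoost.loadModel(s"$tmpPath/8.model")
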

View File

@ -1,133 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite
class XGBoostModelSuite extends FunSuite with PerTest {
test("test model consistency after save and load") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
assert(evalResults < 0.1)
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
val loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
val predicts = loadedXGBoostModel.booster.predict(testSetDMatrix, outPutMargin = true)
val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
assert(loadedEvalResults == evalResults)
}
test("test save and load of different types of models") {
import DataUtils._
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
// validate regression model
var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.setFeaturesCol("feature_col")
xgBoostModel.setLabelCol("label_col")
xgBoostModel.setPredictionCol("prediction_col")
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
assert(loadedXGBoostModel.getLabelCol == "label_col")
assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
// classification model
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5).deep)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
// (multiclass) classification model
trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
}
test("copy and predict ClassificationModel") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testRDD = sc.parallelize(Classification.test).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
testCopy(model, testRDD)
}
test("copy and predict RegressionModel") {
import DataUtils._
val trainingRDD = sc.parallelize(Regression.train).map(_.asML)
val testRDD = sc.parallelize(Regression.test).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "reg:linear")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
testCopy(model, testRDD)
}
private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = {
val modelCopy = model.copy(ParamMap.empty)
modelCopy.summary // Ensure no exception.
val expected = model.predict(testRDD).collect
assert(modelCopy.predict(testRDD).collect === expected)
}
}
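
The deleted suite above covered saveModelAsHadoopFile / loadModelFromHadoopFile, which go away together with XGBoostModel itself. The refactored model classes persist through the standard MLlib ML writer/reader instead; a minimal sketch, with an illustrative path:

// MLlib-style persistence for the refactored model types.
model.write.overwrite().save("/tmp/xgbClassificationModel")
val reloaded = XGBoostClassificationModel.load("/tmp/xgbClassificationModel")
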

View File

@ -0,0 +1,114 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
class XGBoostRegressorSuite extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
val trainingDM = new DMatrix(Regression.train.iterator)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:linear")
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
assert(prediction1.indices.count { i =>
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
} < prediction1.length * 0.1)
}
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:linear",
"num_round" -> round,
"num_workers" -> numWorkers)
// Set params in XGBoost way
val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostRegressor()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("reg:linear")
.setNumRound(round)
.setNumWorkers(numWorkers)
.fit(trainingDF)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) <= 0.01f)
}
}
test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group")
val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
val prediction = model.transform(testDF).collect()
assert(testDF.count() === prediction.length)
}
test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val testDF = buildDataFrame(Regression.test)
val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
val prediction = model.transform(testDF).collect()
val first = prediction.head.getAs[Double]("prediction")
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
}
}
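
As the first test in this suite shows, transform() appends a "prediction" column to the output DataFrame; pulling results back to the driver is then plain DataFrame code. A small usage sketch, with column names following the tests:

val predictions = model.transform(testDF)
  .select("id", "prediction")
  .collect()
  .map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction")))
  .toMap
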

View File

@ -1,138 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import scala.util.Random
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class XGBoostSparkPipelinePersistence extends FunSuite with PerTest
with BeforeAndAfterAll {
override def afterAll(): Unit = {
delete(new File("./testxgbPipe"))
delete(new File("./testxgbEst"))
delete(new File("./testxgbModel"))
delete(new File("./test2xgbModel"))
}
private def delete(f: File) {
if (f.exists()) {
if (f.isDirectory()) {
for (c <- f.listFiles()) {
delete(c)
}
}
if (!f.delete()) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
}
test("test persistence of XGBoostEstimator") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val xgbEstimator = new XGBoostEstimator(paramMap)
xgbEstimator.write.overwrite().save("./testxgbEst")
val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}
test("test persistence of a complete pipeline") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val r = new Random(0)
val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(paramMap)
val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
pipeline.write.overwrite().save("testxgbPipe")
val loadedPipeline = Pipeline.read.load("testxgbPipe")
val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}
test("test persistence of XGBoostModel") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("label")))
.setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(Map("num_round" -> 10,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
)).setFeaturesCol("features").setLabelCol("label")
// separate
val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
predModel.write.overwrite.save("test2xgbModel")
val same2Model = XGBoostModel.load("test2xgbModel")
assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
val predParamMap = predModel.extractParamMap()
val same2ParamMap = same2Model.extractParamMap()
assert(predParamMap.get(predModel.useExternalMemory)
=== same2ParamMap.get(same2Model.useExternalMemory))
assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
assert(predParamMap.get(predModel.predictionCol)
=== same2ParamMap.get(same2Model.predictionCol))
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
// chained
val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
predictionModel.write.overwrite.save("testxgbModel")
val sameModel = PipelineModel.load("testxgbModel")
val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb }.head
val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb }.head
assert(java.util.Arrays.equals(
predictionModelXGB.booster.toByteArray,
sameModelXGB.booster.toByteArray
))
val predictionModelXGBParamMap = predictionModel.extractParamMap()
val sameModelXGBParamMap = sameModel.extractParamMap()
assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
=== sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
=== sameModelXGBParamMap.get(sameModelXGB.featuresCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
=== sameModelXGBParamMap.get(sameModelXGB.predictionCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
}
}

View File

@ -197,6 +197,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
booster.getFeatureScore(featureMap).asScala
}
def getVersion: Int = booster.getVersion
def toByteArray: Array[Byte] = {
booster.toByteArray
}
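
The new toByteArray accessor exposes the booster's serialized form, which is what the persistence tests above compare byte-for-byte. A sketch of a round trip through raw bytes; the InputStream overload of loadModel is assumed to be available in the Scala wrapper:

import java.io.ByteArrayInputStream
import ml.dmlc.xgboost4j.scala.{Booster, XGBoost}

// Serialize a trained booster and restore an equivalent one from the bytes.
val bytes: Array[Byte] = booster.toByteArray
val restored: Booster = XGBoost.loadModel(new ByteArrayInputStream(bytes))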