[jvm-packages] XGBoost Spark integration refactor (#3387)
* Add back the train method but mark it as deprecated
* Fix scalastyle errors
* [jvm-packages] XGBoost Spark integration refactor (#3313); make the corresponding update for xgboost4j-example; address comments
* [jvm-packages] Refactor XGBoost-Spark params to make them compatible with both XGBoost and Spark MLlib (#3326); fix extra space
* [jvm-packages] XGBoost Spark supports ranking with group data (#3369); use Iterator.duplicate to prevent OOM
* Update CheckpointManagerSuite.scala
* Resolve conflicts
This commit is contained in:
parent e6696337e4
commit 2c4359e914
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.Source

import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoost}
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
@@ -160,10 +160,10 @@ object SparkModelTuningTool {
private def crossValidation(
xgboostParam: Map[String, Any],
trainingData: Dataset[_]): TrainValidationSplitModel = {
val xgbEstimator = new XGBoostEstimator(xgboostParam).setFeaturesCol("features").
val xgbEstimator = new XGBoostRegressor(xgboostParam).setFeaturesCol("features").
setLabelCol("logSales")
val paramGrid = new ParamGridBuilder()
.addGrid(xgbEstimator.round, Array(20, 50))
.addGrid(xgbEstimator.numRound, Array(20, 50))
.addGrid(xgbEstimator.eta, Array(0.1, 0.4))
.build()
val tv = new TrainValidationSplit()

@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.example.spark

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf

@@ -45,9 +45,10 @@ object SparkWithDataFrame {
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.trainWithDataFrame(
trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
"objective" -> "binary:logistic",
"num_round" -> numRound,
"nWorkers" -> args(1).toInt).toMap
val xgboostModel = new XGBoostClassifier(paramMap).fit(trainDF)
// xgboost-spark appends the column containing prediction results
xgboostModel.transform(testDF).show()
}
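For readers tracking the API change in this hunk: the static XGBoost.trainWithDataFrame entry point is replaced by an MLlib-style estimator, with the round count and worker count moved into the parameter map. A minimal sketch of the new flow, assuming an existing SparkSession and DataFrames trainDF/testDF with a vector "features" column and a "label" column (the values below are placeholders):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Placeholder values; the "num_round" and "nWorkers" keys mirror the updated example above.
val paramMap = Map(
  "eta" -> 0.1f,
  "max_depth" -> 2,
  "objective" -> "binary:logistic",
  "num_round" -> 100,
  "nWorkers" -> 2)

val model = new XGBoostClassifier(paramMap)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .fit(trainDF)

// xgboost-spark appends the prediction columns to the input DataFrame.
model.transform(testDF).show()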
@@ -1,58 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.example.spark

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost

import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

object SparkWithRDD {
def main(args: Array[String]): Unit = {
if (args.length != 5) {
println(
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
implicit val sc = new SparkContext(sparkConf)
val inputTrainPath = args(2)
val inputTestPath = args(3)
val outputModelPath = args(4)
// number of iterations
val numRound = args(0).toInt
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).map(lp =>
MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray)))
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath)
.map(lp => new MLDenseVector(lp.features.toArray))
// training parameters
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
useExternalMemory = true)
xgboostModel.predict(testSet, missingValue = Float.NaN)
// save model to HDFS path
xgboostModel.saveModelAsHadoopFile(outputModelPath)
}
}
@@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
@@ -63,9 +64,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
val version = versions.max
val fullPath = getPath(version)
logger.info(s"Start training from previous booster at $fullPath")
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
model.booster.booster.setVersion(version)
model.booster
val booster = SXGBoost.loadModel(fullPath)
booster.booster.setVersion(version)
booster
} else {
null
}
@@ -76,12 +77,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
*
* @param checkpoint the checkpoint to save as an XGBoostModel
*/
private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(checkpoint.version)
logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
checkpoint.saveModelAsHadoopFile(fullPath)(sc)
val fullPath = getPath(checkpoint.getVersion)
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
checkpoint.saveModel(fullPath)
prevModelPaths.foreach(path => fs.delete(path, true))
}
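For context on the checkpoint handling above: after the refactor the manager persists raw Boosters, and checkpointing is configured through estimator parameters rather than through calls on an XGBoostModel. A hedged sketch using the setters introduced later in this diff; the path is a placeholder and the interval-in-rounds semantics is an assumption:

val classifier = new XGBoostClassifier(Map("objective" -> "binary:logistic"))
  .setNumRound(100)
  .setCheckpointPath("hdfs:///tmp/xgb-checkpoints") // placeholder directory for saved boosters
  .setCheckpointInterval(10)                        // assumed: save a checkpoint every 10 rounds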
@@ -21,16 +21,15 @@ import java.nio.file.Files

import scala.collection.mutable
import scala.util.Random

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}

@@ -134,7 +133,7 @@ object XGBoost extends Serializable {
fromBaseMarginsToArray(baseMargins), cacheDirName)

try {
val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
.map(_.toString.toInt).getOrElse(0)
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.train, params, round,
@@ -148,89 +147,6 @@ object XGBoost extends Serializable {
}.cache()
}

/**
* Train XGBoost model with the DataFrame-represented data
*
* @param trainingData the training set represented as DataFrame
* @param params Map containing the parameters to configure XGBoost
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing The value which represents a missing value in the dataset
* @param featureCol the name of input column, "features" as default value
* @param labelCol the name of output column, "label" as default value
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
*/
@throws(classOf[XGBoostError])
def trainWithDataFrame(
trainingData: Dataset[_],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN,
featureCol: String = "features",
labelCol: String = "label"): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
val estimator = new XGBoostEstimator(params)
// assigning general parameters
estimator.
set(estimator.useExternalMemory, useExternalMemory).
set(estimator.round, round).
set(estimator.nWorkers, nWorkers).
set(estimator.customObj, obj).
set(estimator.customEval, eval).
set(estimator.missing, missing).
setFeaturesCol(featureCol).
setLabelCol(labelCol).
fit(trainingData)
}

private[spark] def isClassificationTask(params: Map[String, Any]): Boolean = {
val objective = params.getOrElse("objective", params.getOrElse("obj_type", null))
objective != null && {
val objStr = objective.toString
objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" &&
!objStr.startsWith("rank:")
}
}
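The predicate above treats every objective that is not a regression ("regression", "reg:*", "count:poisson") or ranking ("rank:*") objective as classification. A few illustrative calls (the method is private[spark], so these are for reading the logic only):

XGBoost.isClassificationTask(Map("objective" -> "binary:logistic")) // true
XGBoost.isClassificationTask(Map("objective" -> "reg:linear"))      // false
XGBoost.isClassificationTask(Map("objective" -> "rank:pairwise"))   // false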
/**
* Train XGBoost model with the RDD-represented data
*
* @param trainingData the training set represented as RDD
* @param params Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing the value represented the missing value in the dataset
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
*/
@deprecated("Use XGBoost.trainWithRDD instead.")
def train(
trainingData: RDD[MLLabeledPoint],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
}

private def overrideParamsAccordingToTaskCPUs(
params: Map[String, Any],
sc: SparkContext): Map[String, Any] = {
@@ -259,39 +175,8 @@ object XGBoost extends Serializable {
}

/**
* Train XGBoost model with the RDD-represented data
*
* @param trainingData the training set represented as RDD
* @param params Map containing the configuration entries
* @param round the number of iterations
* @param nWorkers the number of xgboost workers, 0 by default which means that the number of
* workers equals to the partition number of trainingData RDD
* @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default
* @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default
* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
* true, the user may save the RAM cost for running XGBoost within Spark
* @param missing The value which represents a missing value in the dataset
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed
* @return XGBoostModel when successful training
* @return A tuple of the booster and the metrics used to build training summary
*/
@throws(classOf[XGBoostError])
def trainWithRDD(
trainingData: RDD[MLLabeledPoint],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
import DataUtils._
val xgbTrainingData = trainingData.map { case MLLabeledPoint(label, features) =>
features.asXGB.copy(label = label.toFloat)
}
trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval,
useExternalMemory, missing)
}

@throws(classOf[XGBoostError])
private[spark] def trainDistributed(
trainingData: RDD[XGBLabeledPoint],
@@ -301,7 +186,7 @@ object XGBoost extends Serializable {
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = {
if (params.contains("tree_method")) {
require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
" for now")
@@ -350,20 +235,15 @@ object XGBoost extends Serializable {
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread)
if (checkpointRound < round) {
prevBooster = model.booster
checkpointManager.updateCheckpoint(model)
prevBooster = booster
checkpointManager.updateCheckpoint(prevBooster)
}
model
(booster, metrics)
} finally {
tracker.stop()
}
@@ -383,17 +263,14 @@ object XGBoost extends Serializable {
private def postTrackerReturnProcessing(
trackerReturnVal: Int,
distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])],
sparkJobThread: Thread,
isClassificationTask: Boolean
): XGBoostModel = {
sparkJobThread: Thread): (Booster, Map[String, Array[Float]]) = {
if (trackerReturnVal == 0) {
// Copies of the final booster and the corresponding metrics
// reside in each partition of the `distributedBoostersAndMetrics`.
// Any of them can be used to create the model.
val (booster, metrics) = distributedBoostersAndMetrics.first()
val xgboostModel = XGBoostModel(booster, isClassificationTask)
distributedBoostersAndMetrics.unpersist(false)
xgboostModel.setSummary(XGBoostTrainingSummary(metrics))
(booster, metrics)
} else {
try {
if (sparkJobThread.isAlive) {
@@ -407,64 +284,6 @@ object XGBoost extends Serializable {
}
}

private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = {
val featureCol = inputStream.readUTF()
val labelCol = inputStream.readUTF()
val predictionCol = inputStream.readUTF()
(featureCol, labelCol, predictionCol)
}

private def setGeneralModelParams(
featureCol: String,
labelCol: String,
predCol: String,
xgBoostModel: XGBoostModel): XGBoostModel = {
xgBoostModel.setFeaturesCol(featureCol)
xgBoostModel.setLabelCol(labelCol)
xgBoostModel.setPredictionCol(predCol)
}

/**
* Load XGBoost model from path in HDFS-compatible file system
*
* @param modelPath The path of the file representing the model
* @return The loaded model
*/
def loadModelFromHadoopFile(modelPath: String)(implicit sparkContext: SparkContext):
XGBoostModel = {
val path = new Path(modelPath)
val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path)
val modelType = dataInStream.readUTF()
val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream)
modelType match {
case "_cls_" =>
val rawPredictionCol = dataInStream.readUTF()
val numClasses = dataInStream.readInt()
val thresholdLength = dataInStream.readInt()
var thresholds: Array[Double] = null
if (thresholdLength != -1) {
thresholds = new Array[Double](thresholdLength)
for (i <- 0 until thresholdLength) {
thresholds(i) = dataInStream.readDouble()
}
}
val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel).
asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol)
if (thresholdLength != -1) {
xgBoostModel.setThresholds(thresholds)
}
xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses
xgBoostModel
case "_reg_" =>
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
case other =>
throw new XGBoostError(s"Unknown model type $other. Supported types " +
s"are: ['_reg_', '_cls_'].")
}
}
}

private class Watches private(
@@ -489,12 +308,29 @@ private class Watches private(

private object Watches {

def buildGroups(groups: Seq[Int]): Seq[Int] = {
val output = mutable.ArrayBuffer.empty[Int]
var count = 1
var i = 1
while (i < groups.length) {
if (groups(i) != groups(i - 1)) {
output += count
count = 1
} else {
count += 1
}
i += 1
}
output += count
output
}
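buildGroups collapses a per-row sequence of group ids into per-group sizes, which is the shape DMatrix.setGroup expects for ranking. A small worked example (illustrative values only):

// Rows tagged with group ids 0, 0, 0, 1, 1, 2 yield group sizes 3, 2, 1.
Watches.buildGroups(Seq(0, 0, 0, 1, 1, 2)) // returns Seq(3, 2, 1)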
def apply(
params: Map[String, Any],
labeledPoints: Iterator[XGBLabeledPoint],
baseMarginsOpt: Option[Array[Float]],
cacheDirName: Option[String]): Watches = {
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0)
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
val r = new Random(seed)
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
@@ -506,8 +342,18 @@ private object Watches {

accepted
}
val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)

val (trainIter1, trainIter2) = trainPoints.duplicate
val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull)
val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray
trainMatrix.setGroup(trainGroups)

val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
if (trainTestRatio < 1.0) {
val testGroups = buildGroups(testPoints.map(_.group)).toArray
testMatrix.setGroup(testGroups)
}

r.setSeed(seed)
for (baseMargins <- baseMarginsOpt) {
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
@@ -515,11 +361,6 @@ private object Watches {
testMatrix.setBaseMargin(testMargin)
}

// TODO: use group attribute from the points.
if (params.contains("groupData") && params("groupData") != null) {
trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
TaskContext.getPartitionId()).toArray)
}
new Watches(trainMatrix, testMatrix, cacheDirName)
}
}
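Note that several parameter keys move from camelCase to snake_case in this refactor (trainTestRatio to train_test_ratio here, numEarlyStoppingRounds to num_early_stopping_rounds earlier). A minimal sketch of a params map using the new keys (values are placeholders):

val params = Map(
  "objective" -> "binary:logistic",
  "train_test_ratio" -> 0.8,        // hold out part of each partition for the test watch
  "num_early_stopping_rounds" -> 5, // placeholder value
  "seed" -> 123L)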
@@ -1,181 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}

/**
* class of the XGBoost model used for classification task
*/
class XGBoostClassificationModel private[spark](
override val uid: String, booster: Booster)
extends XGBoostModel(booster) {

def this(booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), booster)

// only called in copy()
def this(uid: String) = this(uid, null)

// scalastyle:off

/**
* whether to output raw margin
*/
final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value")

setDefault(outputMargin, false)

def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel]

/**
* the name of the column storing the raw prediction value, either probabilities (as default) or
* raw margin value
*/
final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).")

setDefault(rawPredictionCol, "probabilities")

final def getRawPredictionCol: String = $(rawPredictionCol)

def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel]

/**
* Thresholds in multi-class classification
*/
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))

def getThresholds: Array[Double] = $(thresholds)

def setThresholds(value: Array[Double]): XGBoostClassificationModel =
set(thresholds, value).asInstanceOf[XGBoostClassificationModel]

// scalastyle:on

// generate dataframe containing raw prediction column which is typed as Vector
private def predictRaw(
testSet: Dataset[_],
temporalColName: Option[String] = None,
forceTransformedScore: Option[Boolean] = None): DataFrame = {
val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin)))
val colName = temporalColName.getOrElse($(rawPredictionCol))
val tempColName = colName + "_arraytype"
val dsWithArrayTypedRawPredCol = testSet.sparkSession.createDataFrame(predictRDD, schema = {
testSet.schema.add(tempColName, ArrayType(FloatType, containsNull = false))
})
val transformerForProbabilitiesArray =
(rawPredArray: mutable.WrappedArray[Float]) =>
if (numClasses == 2) {
Array(1 - rawPredArray(0), rawPredArray(0)).map(_.toDouble)
} else {
rawPredArray.map(_.toDouble).array
}
dsWithArrayTypedRawPredCol.withColumn(colName,
udf((rawPredArray: mutable.WrappedArray[Float]) =>
new MLDenseVector(transformerForProbabilitiesArray(rawPredArray))).apply(col(tempColName))).
drop(tempColName)
}

private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = {
val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol"))
val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol"))
val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF)
val allColumnNames = testSet.columns ++ Seq($(predictionCol))
tempDF.select(allColumnNames(0), allColumnNames.tail: _*)
}

private def argMax(vector: Array[Double]): Double = {
vector.zipWithIndex.maxBy(_._1)._2
}

private def raw2prediction(rawPrediction: MLDenseVector): Double = {
if (!isDefined(thresholds)) {
argMax(rawPrediction.values)
} else {
probability2prediction(rawPrediction)
}
}

private def probability2prediction(probability: MLDenseVector): Double = {
if (!isDefined(thresholds)) {
argMax(probability.values)
} else {
val thresholds: Array[Double] = getThresholds
val scaledProbability =
probability.values.zip(thresholds).map { case (p, t) =>
if (t == 0.0) Double.PositiveInfinity else p / t
}
argMax(scaledProbability)
}
}

override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
transformSchema(testSet.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
if ($(outputMargin)) {
setRawPredictionCol("margin")
}
var outputData = testSet
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
outputData = predictRaw(testSet)
numColsOutput += 1
}

if ($(predictionCol).nonEmpty) {
if ($(rawPredictionCol).nonEmpty) {
require(!$(outputMargin), "XGBoost does not support output final prediction with" +
" untransformed margin. Please set predictionCol as \"\" when setting outputMargin as" +
" true")
val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol)))
outputData = outputData.withColumn($(predictionCol), rawToPredUDF)
} else {
outputData = fromFeatureToPrediction(testSet)
}
numColsOutput += 1
}

if (numColsOutput == 0) {
this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData.toDF()
}

private[spark] var numOfClasses = 2

def numClasses: Int = numOfClasses

override def copy(extra: ParamMap): XGBoostClassificationModel = {
val newModel = copyValues(new XGBoostClassificationModel(booster), extra)
newModel.setSummary(summary)
}

override protected def predict(features: MLVector): Double = {
throw new Exception("XGBoost does not support online prediction ")
}
}
@@ -0,0 +1,432 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._
import scala.collection.mutable

import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.json4s.DefaultFormats

private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs

class XGBoostClassifier (
override val uid: String,
private val xgboostParams: Map[String, Any])
extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
with XGBoostClassifierParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]())

def this(uid: String) = this(uid, Map[String, Any]())

def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbc"), xgboostParams)

XGBoostToMLlibParams(xgboostParams)

def setWeightCol(value: String): this.type = set(weightCol, value)

def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)

def setNumClass(value: Int): this.type = set(numClass, value)

// setters for general params
def setNumRound(value: Int): this.type = set(numRound, value)

def setNumWorkers(value: Int): this.type = set(numWorkers, value)

def setNthread(value: Int): this.type = set(nthread, value)

def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)

def setSilent(value: Int): this.type = set(silent, value)

def setMissing(value: Float): this.type = set(missing, value)

def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value)

def setCheckpointPath(value: String): this.type = set(checkpointPath, value)

def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

def setSeed(value: Long): this.type = set(seed, value)

// setters for booster params
def setBooster(value: String): this.type = set(booster, value)

def setEta(value: Double): this.type = set(eta, value)

def setGamma(value: Double): this.type = set(gamma, value)

def setMaxDepth(value: Int): this.type = set(maxDepth, value)

def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)

def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)

def setSubsample(value: Double): this.type = set(subsample, value)

def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)

def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)

def setLambda(value: Double): this.type = set(lambda, value)

def setAlpha(value: Double): this.type = set(alpha, value)

def setTreeMethod(value: String): this.type = set(treeMethod, value)

def setGrowPolicy(value: String): this.type = set(growPolicy, value)

def setMaxBins(value: Int): this.type = set(maxBins, value)

def setSketchEps(value: Double): this.type = set(sketchEps, value)

def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)

def setSampleType(value: String): this.type = set(sampleType, value)

def setNormalizeType(value: String): this.type = set(normalizeType, value)

def setRateDrop(value: Double): this.type = set(rateDrop, value)

def setSkipDrop(value: Double): this.type = set(skipDrop, value)

def setLambdaBias(value: Double): this.type = set(lambdaBias, value)

// setters for learning params
def setObjective(value: String): this.type = set(objective, value)

def setBaseScore(value: Double): this.type = set(baseScore, value)

def setEvalMetric(value: String): this.type = set(evalMetric, value)

def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)

def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)

// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("multi")) {
// multi
"merror"
} else {
// binary
"error"
}
}

override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {

if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
}

val _numClasses = getNumClasses(dataset)
if (isDefined(numClass) && $(numClass) != _numClasses) {
throw new Exception("The number of classes in dataset doesn't match " +
"\'num_class\' in xgboost params.")
}

val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)
} else {
col($(baseMarginCol))
}

val instances: RDD[XGBLabeledPoint] = dataset.select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
baseMargin.cast(FloatType),
weight.cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing))
val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
model
}

override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra)
}

object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {

override def load(path: String): XGBoostClassifier = super.load(path)
}
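A brief usage sketch of the estimator defined above, mixing the Map-based constructor with the MLlib-style setters (the training and test DataFrames are assumed to exist with "features" and "label" columns):

val xgb = new XGBoostClassifier(Map("objective" -> "binary:logistic"))
  .setNumRound(50)
  .setNumWorkers(4)
  .setEta(0.1)
  .setMaxDepth(6)

val model = xgb.fit(trainingDF)
// transform appends rawPrediction, probability and prediction columns when their names are set.
val scored = model.transform(testDF)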
class XGBoostClassificationModel private[ml](
override val uid: String,
override val numClasses: Int,
private[spark] val _booster: Booster)
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostClassifierParams with MLWritable with Serializable {

import XGBoostClassificationModel._

// only called in copy()
def this(uid: String) = this(uid, 2, null)

private var trainingSummary: Option[XGBoostTrainingSummary] = None

/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}

private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}

// TODO: Make it public after we resolve performance issue
private def margin(features: Vector): Array[Float] = {
import DataUtils._
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
_booster.predict(data = dm, outPutMargin = true)(0)
}

private def probability(features: Vector): Array[Float] = {
import DataUtils._
val dm = new DMatrix(scala.collection.Iterator(features.asXGB))
_booster.predict(data = dm, outPutMargin = false)(0)
}

override def predict(features: Vector): Double = {
throw new Exception("XGBoost-Spark does not support online prediction")
}

// Actually we don't use this function at all, to make it pass compiler check.
override def predictRaw(features: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
}

// Actually we don't use this function at all, to make it pass compiler check.
override def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
}

// Generate raw prediction and probability prediction.
private def transformInternal(dataset: Dataset[_]): DataFrame = {

val schema = StructType(dataset.schema.fields ++
Seq(StructField(name = _rawPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)) ++
Seq(StructField(name = _probabilityCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))

val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName

val rdd = dataset.rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
$(featuresCol))).toList.iterator
import DataUtils._
val cacheInfo = {
if ($(useExternalMemory)) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}

val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
try {
val rawPredictionItr = {
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator
}
val probabilityItr = {
bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator
}
Rabit.shutdown()
rowItr1.zip(rawPredictionItr).zip(probabilityItr).map {
case ((originals: Row, rawPrediction: Row), probability: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
}
} finally {
dm.delete()
}
} else {
Iterator[Row]()
}
}

bBooster.unpersist(blocking = false)

dataset.sparkSession.createDataFrame(rdd, schema)
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = transformInternal(dataset)
var numColsOutput = 0

val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) =>
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
}

val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) =>
if (numClasses == 2) {
Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
} else {
Vectors.dense(probability.map(_.toDouble).toArray)
}
}

val predictUDF = udf { (probability: mutable.WrappedArray[Float]) =>
// From XGBoost probability to MLlib prediction
val probabilities = if (numClasses == 2) {
Array(1 - probability(0), probability(0)).map(_.toDouble)
} else {
probability.map(_.toDouble).toArray
}
probability2prediction(Vectors.dense(probabilities))
}

if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
numColsOutput += 1
}

if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}

if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}

if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData
.toDF
.drop(col(_rawPredictionCol))
.drop(col(_probabilityCol))
}

override def copy(extra: ParamMap): XGBoostClassificationModel = {
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
newModel.setSummary(summary).setParent(parent)
}

override def write: MLWriter =
new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
}

object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {

private val _rawPredictionCol = "_rawPrediction"
private val _probabilityCol = "_probability"

override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader

override def load(path: String): XGBoostClassificationModel = super.load(path)

private[XGBoostClassificationModel]
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext

DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
outputStream.writeInt(instance.numClasses)
instance._booster.saveModel(outputStream)
outputStream.close()
}
}

private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] {

/** Checked against metadata when loading model */
private val className = classOf[XGBoostClassificationModel].getName

override def load(path: String): XGBoostClassificationModel = {
implicit val sc = super.sparkSession.sparkContext

val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
val numClasses = dataInStream.readInt()

val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
model
}
}
}
@@ -1,186 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.ml.Predictor
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.{Dataset, Row}
import org.json4s.DefaultFormats

/**
* XGBoost Estimator to produce a XGBoost model
*/
class XGBoostEstimator private[spark](
override val uid: String, xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostEstimator, XGBoostModel]
with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

def this(xgboostParams: Map[String, Any]) =
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])

def this(uid: String) = this(uid, Map[String, Any]())

// called in fromXGBParamMapToParams only when eval_metric is not defined
private def setupDefaultEvalMetric(): String = {
val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null))
if (objFunc == null) {
"rmse"
} else {
// compute default metric based on specified objective
val isClassificationTask = XGBoost.isClassificationTask(xgboostParams)
if (!isClassificationTask) {
// default metric for regression or ranking
if (objFunc.toString.startsWith("rank")) {
"map"
} else {
"rmse"
}
} else {
// default metric for classification
if (objFunc.toString.startsWith("multi")) {
// multi
"merror"
} else {
// binary
"error"
}
}
}
}

private def fromXGBParamMapToParams(): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
params.find(_.name == paramName) match {
case None =>
case Some(_: DoubleParam) =>
set(paramName, paramValue.toString.toDouble)
case Some(_: BooleanParam) =>
set(paramName, paramValue.toString.toBoolean)
case Some(_: IntParam) =>
set(paramName, paramValue.toString.toInt)
case Some(_: FloatParam) =>
set(paramName, paramValue.toString.toFloat)
case Some(_: Param[_]) =>
set(paramName, paramValue)
}
}
if (xgboostParams.get("eval_metric").isEmpty) {
set("eval_metric", setupDefaultEvalMetric())
}
}

fromXGBParamMapToParams()

private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
val xgbParamMap = new mutable.HashMap[String, Any]()
for (param <- params) {
xgbParamMap += param.name -> $(param)
}
val r = xgbParamMap.toMap
if (!XGBoost.isClassificationTask(r) || $(numClasses) == 2) {
r - "num_class"
} else {
r
}
}

private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = {
var newTrainingSet = trainingSet
if (!trainingSet.columns.contains($(baseMarginCol))) {
newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN))
}
if (!trainingSet.columns.contains($(weightCol))) {
newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0))
}
newTrainingSet
}

/**
* produce a XGBoostModel by fitting the given dataset
*/
override def train(trainingSet: Dataset[_]): XGBoostModel = {
val instances = ensureColumns(trainingSet).select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
col($(baseMarginCol)).cast(FloatType),
col($(weightCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
}
transformSchema(trainingSet.schema, logging = true)
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
val trainedModel = XGBoost.trainDistributed(instances, derivedXGBoosterParamMap,
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing)).setParent(this)
val returnedModel = copyValues(trainedModel, extractParamMap())
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
}
returnedModel
}

override def copy(extra: ParamMap): XGBoostEstimator = {
defaultCopy(extra).asInstanceOf[XGBoostEstimator]
}

override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this)
}

object XGBoostEstimator extends MLReadable[XGBoostEstimator] {

override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader

override def load(path: String): XGBoostEstimator = super.load(path)

private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
require(instance.fromParamsToXGBParamMap("custom_eval") == null &&
instance.fromParamsToXGBParamMap("custom_obj") == null,
"we do not support persist XGBoostEstimator with customized evaluator and objective" +
" function for now")
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
}
}

private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] {

override def load(path: String): XGBoostEstimator = {
val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc)
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
DefaultXGBoostParamsReader.getAndSetParams(instance, metadata)
instance.asInstanceOf[XGBoostEstimator]
}
}
}
@@ -1,387 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}

import org.apache.hadoop.fs.{FSDataOutputStream, Path}

import org.apache.spark.ml.PredictionModel
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.types.{ArrayType, FloatType}
import org.apache.spark.{SparkContext, TaskContext}
import org.json4s.DefaultFormats

/**
* the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
*/
abstract class XGBoostModel(protected var _booster: Booster)
extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable
with Params with MLWritable {

private var trainingSummary: Option[XGBoostTrainingSummary] = None

/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}

private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}

def setLabelCol(name: String): XGBoostModel = set(labelCol, name)

// scalastyle:off

final val useExternalMemory = new BooleanParam(this, "use_external_memory",
"whether to use external memory for prediction")

setDefault(useExternalMemory, false)

def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value)

// scalastyle:on

/**
* Predict leaf instances with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
*/
def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
if (testSamples.nonEmpty) {
val dMatrix = new DMatrix(testSamples.map(_.asXGB))
try {
broadcastBooster.value.predictLeaf(dMatrix).iterator
} finally {
Rabit.shutdown()
dMatrix.delete()
}
} else {
Iterator()
}
}
}

/**
* evaluate XGBoostModel with a RDD-wrapped dataset
*
* NOTE: you have to specify value of either eval or iter; when you specify both, this method
* adopts the default eval metric of model
*
* @param evalDataset the dataset used for evaluation
* @param evalName the name of evaluation
* @param evalFunc the customized evaluation function, null by default to use the default metric
* of model
* @param iter the current iteration, -1 to be null to use customized evaluation functions
* @param groupData group data specify each group size for ranking task. Top level corresponds
* to partition id, second level is the group sizes.
* @return the average metric over all partitions
*/
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false,
groupData: Seq[Seq[Int]] = null): String = {
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
val appName = evalDataset.context.appName
|
||||
val allEvalMetrics = evalDataset.mapPartitions {
|
||||
labeledPointsPartition =>
|
||||
import DataUtils._
|
||||
if (labeledPointsPartition.hasNext) {
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName = {
|
||||
if (broadcastUseExternalCache.value) {
|
||||
s"$appName-${TaskContext.get().stageId()}-$evalName" +
|
||||
s"-deval_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(labeledPointsPartition.map(_.asXGB), cacheFileName)
|
||||
try {
|
||||
if (groupData != null) {
|
||||
dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray)
|
||||
}
|
||||
(evalFunc, iter) match {
|
||||
case (null, _) => {
|
||||
val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
|
||||
val Array(evName, predNumeric) = predStr.split(":")
|
||||
Iterator(Some(evName, predNumeric.toFloat))
|
||||
}
|
||||
case _ => {
|
||||
val predictions = broadcastBooster.value.predict(dMatrix)
|
||||
Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator(None)
|
||||
}
|
||||
}.filter(_.isDefined).collect()
|
||||
val evalPrefix = allEvalMetrics.map(_.get._1).head
|
||||
val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
|
||||
s"$evalPrefix = $evalMetricMean"
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given test set (represented as RDD)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
* @param missingValue the specified value to represent the missing value
|
||||
*/
|
||||
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val sampleArray = testSamples.toArray
|
||||
val numRows = sampleArray.length
|
||||
if (numRows == 0) {
|
||||
Iterator()
|
||||
} else {
|
||||
val numColumns = sampleArray.head.size
|
||||
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
// translate to required format
|
||||
val flatSampleArray = new Array[Float](numRows * numColumns)
|
||||
for (i <- flatSampleArray.indices) {
|
||||
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
|
||||
}
|
||||
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
|
||||
try {
|
||||
broadcastBooster.value.predict(dMatrix).iterator
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict result with the given test set (represented as RDD)
|
||||
*
|
||||
* @param testSet test set represented as RDD
|
||||
* @param useExternalCache whether to use external cache for the test set
|
||||
* @param outputMargin whether to output raw untransformed margin value
|
||||
*/
|
||||
def predict(
|
||||
testSet: RDD[MLVector],
|
||||
useExternalCache: Boolean = false,
|
||||
outputMargin: Boolean = false): RDD[Array[Float]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
|
||||
val appName = testSet.context.appName
|
||||
testSet.mapPartitions { testSamples =>
|
||||
if (testSamples.nonEmpty) {
|
||||
import DataUtils._
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val cacheFileName = {
|
||||
if (useExternalCache) {
|
||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val dMatrix = new DMatrix(testSamples.map(_.asXGB), cacheFileName)
|
||||
try {
|
||||
broadcastBooster.value.predict(dMatrix).iterator
|
||||
} finally {
|
||||
Rabit.shutdown()
|
||||
dMatrix.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def transformImpl(testSet: Dataset[_]): DataFrame
|
||||
|
||||
/**
|
||||
* append leaf index of each row as an additional column in the original dataset
|
||||
*
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
def transformLeaf(testSet: Dataset[_]): DataFrame = {
|
||||
val predictRDD = produceRowRDD(testSet, predLeaf = true)
|
||||
setPredictionCol("predLeaf")
|
||||
transformSchema(testSet.schema, logging = true)
|
||||
testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol),
|
||||
ArrayType(FloatType, containsNull = false)))
|
||||
}
|
||||
|
||||
protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false,
|
||||
predLeaf: Boolean = false): RDD[Row] = {
|
||||
val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
|
||||
val appName = testSet.sparkSession.sparkContext.appName
|
||||
testSet.rdd.mapPartitions {
|
||||
rowIterator =>
|
||||
if (rowIterator.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val (rowItr1, rowItr2) = rowIterator.duplicate
|
||||
val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector](
|
||||
$(featuresCol))).toList.iterator
|
||||
import DataUtils._
|
||||
val cachePrefix = {
|
||||
if ($(useExternalMemory)) {
|
||||
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
val testDataset = new DMatrix(vectorIterator.map(_.asXGB), cachePrefix)
|
||||
try {
|
||||
val rawPredictResults = {
|
||||
if (!predLeaf) {
|
||||
broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator
|
||||
} else {
|
||||
broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
|
||||
}
|
||||
}
|
||||
Rabit.shutdown()
|
||||
// concatenate original data partition and predictions
|
||||
rowItr1.zip(rawPredictResults).map {
|
||||
case (originalColumns: Row, predictColumn: Row) =>
|
||||
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
|
||||
}
|
||||
} finally {
|
||||
testDataset.delete()
|
||||
}
|
||||
} else {
|
||||
Iterator[Row]()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* produces the prediction results and append as an additional column in the original dataset
|
||||
* NOTE: the prediction results is kept as the original format of xgboost
|
||||
*
|
||||
* @return the original dataframe with an additional column containing prediction results
|
||||
*/
|
||||
override def transform(testSet: Dataset[_]): DataFrame = {
|
||||
transformImpl(testSet)
|
||||
}
|
||||
|
||||
private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = {
|
||||
outputStream.writeUTF(getFeaturesCol)
|
||||
outputStream.writeUTF(getLabelCol)
|
||||
outputStream.writeUTF(getPredictionCol)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as to HDFS-compatible file system.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = {
|
||||
val path = new Path(modelPath)
|
||||
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
|
||||
// output model type
|
||||
this match {
|
||||
case model: XGBoostClassificationModel =>
|
||||
outputStream.writeUTF("_cls_")
|
||||
saveGeneralModelParam(outputStream)
|
||||
outputStream.writeUTF(model.getRawPredictionCol)
|
||||
outputStream.writeInt(model.numClasses)
|
||||
// threshold
|
||||
// threshold length
|
||||
if (!isDefined(model.thresholds)) {
|
||||
outputStream.writeInt(-1)
|
||||
} else {
|
||||
val thresholdLength = model.getThresholds.length
|
||||
outputStream.writeInt(thresholdLength)
|
||||
for (i <- 0 until thresholdLength) {
|
||||
outputStream.writeDouble(model.getThresholds(i))
|
||||
}
|
||||
}
|
||||
case model: XGBoostRegressionModel =>
|
||||
outputStream.writeUTF("_reg_")
|
||||
// eventual prediction col
|
||||
saveGeneralModelParam(outputStream)
|
||||
}
|
||||
// booster
|
||||
_booster.saveModel(outputStream)
|
||||
outputStream.close()
|
||||
}
|
||||
|
||||
def booster: Booster = _booster
|
||||
|
||||
def version: Int = this.booster.booster.getVersion
|
||||
|
||||
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
|
||||
|
||||
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
|
||||
}
|
||||
|
||||
object XGBoostModel extends MLReadable[XGBoostModel] {
|
||||
private[spark] def apply(booster: Booster, isClassification: Boolean): XGBoostModel = {
|
||||
if (!isClassification) {
|
||||
new XGBoostRegressionModel(booster)
|
||||
} else {
|
||||
new XGBoostClassificationModel(booster)
|
||||
}
|
||||
}
|
||||
|
||||
override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader
|
||||
|
||||
override def load(path: String): XGBoostModel = super.load(path)
|
||||
|
||||
private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter {
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
implicit val format = DefaultFormats
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
|
||||
val dataPath = new Path(path, "data").toString
|
||||
instance.saveModelAsHadoopFile(dataPath)
|
||||
}
|
||||
}
|
||||
|
||||
private class XGBoostModelModelReader extends MLReader[XGBoostModel] {
|
||||
|
||||
override def load(path: String): XGBoostModel = {
|
||||
implicit val sc = super.sparkSession.sparkContext
|
||||
val dataPath = new Path(path, "data").toString
|
||||
// not used / all data resides in platform independent xgboost model file
|
||||
// val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
|
||||
XGBoost.loadModelFromHadoopFile(dataPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,61 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, FloatType}

/**
* class of XGBoost model used for regression task
*/
class XGBoostRegressionModel private[spark](override val uid: String, booster: Booster)
extends XGBoostModel(booster) {

def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster)

// only called in copy()
def this(uid: String) = this(uid, null)

override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
transformSchema(testSet.schema, logging = true)
val predictRDD = produceRowRDD(testSet)
val tempPredColName = $(predictionCol) + "_temp"
val transformerForArrayTypedPredCol =
udf((regressionResults: mutable.WrappedArray[Float]) => regressionResults(0))
testSet.sparkSession.createDataFrame(predictRDD,
schema = testSet.schema.add(tempPredColName, ArrayType(FloatType, containsNull = false))
).withColumn(
$(predictionCol),
transformerForArrayTypedPredCol.apply(col(tempPredColName))).drop(tempPredColName)
}

override protected def predict(features: MLVector): Double = {
throw new Exception("XGBoost does not support online prediction for now")
}

override def copy(extra: ParamMap): XGBoostRegressionModel = {
val newModel = copyValues(new XGBoostRegressionModel(booster), extra)
newModel.setSummary(summary)
}
}
@ -0,0 +1,356 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}

import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.json4s.DefaultFormats

import scala.collection.mutable

private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
with ParamMapFuncs

class XGBoostRegressor (
override val uid: String,
private val xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
with XGBoostRegressorParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]())

def this(uid: String) = this(uid, Map[String, Any]())

def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbr"), xgboostParams)

XGBoostToMLlibParams(xgboostParams)

def setWeightCol(value: String): this.type = set(weightCol, value)

def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)

def setGroupCol(value: String): this.type = set(groupCol, value)

// setters for general params
def setNumRound(value: Int): this.type = set(numRound, value)

def setNumWorkers(value: Int): this.type = set(numWorkers, value)

def setNthread(value: Int): this.type = set(nthread, value)

def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)

def setSilent(value: Int): this.type = set(silent, value)

def setMissing(value: Float): this.type = set(missing, value)

def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value)

def setCheckpointPath(value: String): this.type = set(checkpointPath, value)

def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

def setSeed(value: Long): this.type = set(seed, value)

// setters for booster params
def setBooster(value: String): this.type = set(booster, value)

def setEta(value: Double): this.type = set(eta, value)

def setGamma(value: Double): this.type = set(gamma, value)

def setMaxDepth(value: Int): this.type = set(maxDepth, value)

def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)

def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)

def setSubsample(value: Double): this.type = set(subsample, value)

def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)

def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)

def setLambda(value: Double): this.type = set(lambda, value)

def setAlpha(value: Double): this.type = set(alpha, value)

def setTreeMethod(value: String): this.type = set(treeMethod, value)

def setGrowPolicy(value: String): this.type = set(growPolicy, value)

def setMaxBins(value: Int): this.type = set(maxBins, value)

def setSketchEps(value: Double): this.type = set(sketchEps, value)

def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)

def setSampleType(value: String): this.type = set(sampleType, value)

def setNormalizeType(value: String): this.type = set(normalizeType, value)

def setRateDrop(value: Double): this.type = set(rateDrop, value)

def setSkipDrop(value: Double): this.type = set(skipDrop, value)

def setLambdaBias(value: Double): this.type = set(lambdaBias, value)

// setters for learning params
def setObjective(value: String): this.type = set(objective, value)

def setBaseScore(value: Double): this.type = set(baseScore, value)

def setEvalMetric(value: String): this.type = set(evalMetric, value)

def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)

def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)

// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("rank")) {
"map"
} else {
"rmse"
}
}

override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {

if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
set(evalMetric, setupDefaultEvalMetric())
}

val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) {
lit(Float.NaN)
} else {
col($(baseMarginCol))
}
val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol))

val instances: RDD[XGBLabeledPoint] = dataset.select(
col($(labelCol)).cast(FloatType),
col($(featuresCol)),
weight.cast(FloatType),
group.cast(IntegerType),
baseMargin.cast(FloatType)
).rdd.map {
case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
}
transformSchema(dataset.schema, logging = true)
val derivedXGBParamMap = MLlib2XGBoostParams
// All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap,
$(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing))
val model = new XGBoostRegressionModel(uid, _booster)
val summary = XGBoostTrainingSummary(_metrics)
model.setSummary(summary)
model
}

override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra)
}

object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {

override def load(path: String): XGBoostRegressor = super.load(path)
}

class XGBoostRegressionModel private[ml] (
override val uid: String,
private[spark] val _booster: Booster)
extends PredictionModel[Vector, XGBoostRegressionModel]
with XGBoostRegressorParams with MLWritable with Serializable {

import XGBoostRegressionModel._

// only called in copy()
def this(uid: String) = this(uid, null)

private var trainingSummary: Option[XGBoostTrainingSummary] = None

/**
* Returns summary (e.g. train/test objective history) of model on the
* training set. An exception is thrown if no summary is available.
*/
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new IllegalStateException("No training summary available for this XGBoostModel")
}

private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
trainingSummary = Some(summary)
this
}

override def predict(features: Vector): Double = {
throw new Exception("XGBoost-Spark does not support online prediction")
}

private def transformInternal(dataset: Dataset[_]): DataFrame = {

val schema = StructType(dataset.schema.fields ++
Seq(StructField(name = _originalPredictionCol, dataType =
ArrayType(FloatType, containsNull = false), nullable = false)))

val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName

val rdd = dataset.rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
$(featuresCol))).toList.iterator
import DataUtils._
val cacheInfo = {
if ($(useExternalMemory)) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}

val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
try {
val originalPredictionItr = {
bBooster.value.predict(dm).map(Row(_)).iterator
}
Rabit.shutdown()
rowItr1.zip(originalPredictionItr).map {
case (originals: Row, originalPrediction: Row) =>
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
}
} finally {
dm.delete()
}
} else {
Iterator[Row]()
}
}

bBooster.unpersist(blocking = false)

dataset.sparkSession.createDataFrame(rdd, schema)
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)

// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = transformInternal(dataset)
var numColsOutput = 0

val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
originalPrediction(0).toDouble
}

if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_originalPredictionCol)))
numColsOutput += 1
}

if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData.toDF.drop(col(_originalPredictionCol))
}

override def copy(extra: ParamMap): XGBoostRegressionModel = {
val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra)
newModel.setSummary(summary).setParent(parent)
}

override def write: MLWriter =
new XGBoostRegressionModel.XGBoostRegressionModelWriter(this)
}

object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

private val _originalPredictionCol = "_originalPrediction"

override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader

override def load(path: String): XGBoostRegressionModel = super.load(path)

private[XGBoostRegressionModel]
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
outputStream.close()
}
}

private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] {

/** Checked against metadata when loading model */
private val className = classOf[XGBoostRegressionModel].getName

override def load(path: String): XGBoostRegressionModel = {
implicit val sc = super.sparkSession.sparkContext

val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)

val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostRegressionModel(metadata.uid, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
model
}
}
}
@ -20,40 +20,48 @@ import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
|
||||
|
||||
trait BoosterParams extends Params {
|
||||
private[spark] trait BoosterParams extends Params {
|
||||
|
||||
/**
|
||||
* Booster to use, options: {'gbtree', 'gblinear', 'dart'}
|
||||
*/
|
||||
val boosterType = new Param[String](this, "booster",
|
||||
final val booster = new Param[String](this, "booster",
|
||||
s"Booster to use, options: {'gbtree', 'gblinear', 'dart'}",
|
||||
(value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase))
|
||||
|
||||
final def getBooster: String = $(booster)
|
||||
|
||||
/**
|
||||
* step size shrinkage used in update to prevents overfitting. After each boosting step, we
|
||||
* can directly get the weights of new features and eta actually shrinks the feature weights
|
||||
* to make the boosting process more conservative. [default=0.3] range: [0,1]
|
||||
*/
|
||||
val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
|
||||
final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
|
||||
" overfitting. After each boosting step, we can directly get the weights of new features." +
|
||||
" and eta actually shrinks the feature weights to make the boosting process more conservative.",
|
||||
(value: Double) => value >= 0 && value <= 1)
|
||||
|
||||
final def getEta: Double = $(eta)
|
||||
|
||||
/**
|
||||
* minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||
* the larger, the more conservative the algorithm will be. [default=0] range: [0,
|
||||
* Double.MaxValue]
|
||||
*/
|
||||
val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a further" +
|
||||
" partition on a leaf node of the tree. the larger, the more conservative the algorithm" +
|
||||
" will be.", (value: Double) => value >= 0)
|
||||
final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
|
||||
"further partition on a leaf node of the tree. the larger, the more conservative the " +
|
||||
"algorithm will be.", (value: Double) => value >= 0)
|
||||
|
||||
final def getGamma: Double = $(gamma)
|
||||
|
||||
/**
|
||||
* maximum depth of a tree, increase this value will make model more complex / likely to be
|
||||
* overfitting. [default=6] range: [1, Int.MaxValue]
|
||||
*/
|
||||
val maxDepth = new IntParam(this, "max_depth", "maximum depth of a tree, increase this value" +
|
||||
" will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
|
||||
final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
|
||||
"value will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
|
||||
|
||||
final def getMaxDepth: Int = $(maxDepth)
|
||||
|
||||
/**
|
||||
* minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
|
||||
@ -62,13 +70,15 @@ trait BoosterParams extends Params {
|
||||
* to minimum number of instances needed to be in each node. The larger, the more conservative
|
||||
* the algorithm will be. [default=1] range: [0, Double.MaxValue]
|
||||
*/
|
||||
val minChildWeight = new DoubleParam(this, "min_child_weight", "minimum sum of instance" +
|
||||
final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
|
||||
" weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
|
||||
" the sum of instance weight less than min_child_weight, then the building process will" +
|
||||
" give up further partitioning. In linear regression mode, this simply corresponds to minimum" +
|
||||
" number of instances needed to be in each node. The larger, the more conservative" +
|
||||
" the algorithm will be.", (value: Double) => value >= 0)
|
||||
|
||||
final def getMinChildWeight: Double = $(minChildWeight)
|
||||
|
||||
/**
|
||||
* Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
|
||||
* means there is no constraint. If it is set to a positive value, it can help making the update
|
||||
@ -76,90 +86,113 @@ trait BoosterParams extends Params {
|
||||
* regression when class is extremely imbalanced. Set it to value of 1-10 might help control the
|
||||
* update. [default=0] range: [0, Double.MaxValue]
|
||||
*/
|
||||
val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow each" +
|
||||
" tree's weight" +
|
||||
final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " +
|
||||
"each tree's weight" +
|
||||
" estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
|
||||
" to a positive value, it can help making the update step more conservative. Usually this" +
|
||||
" parameter is not needed, but it might help in logistic regression when class is extremely" +
|
||||
" imbalanced. Set it to value of 1-10 might help control the update",
|
||||
(value: Double) => value >= 0)
|
||||
|
||||
final def getMaxDeltaStep: Double = $(maxDeltaStep)
|
||||
|
||||
/**
|
||||
* subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly
|
||||
* collected half of the data instances to grow trees and this will prevent overfitting.
|
||||
* [default=1] range:(0,1]
|
||||
*/
|
||||
val subSample = new DoubleParam(this, "subsample", "subsample ratio of the training instance." +
|
||||
" Setting it to 0.5 means that XGBoost randomly collected half of the data instances to" +
|
||||
" grow trees and this will prevent overfitting.", (value: Double) => value <= 1 && value > 0)
|
||||
final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
|
||||
"instance. Setting it to 0.5 means that XGBoost randomly collected half of the data " +
|
||||
"instances to grow trees and this will prevent overfitting.",
|
||||
(value: Double) => value <= 1 && value > 0)
|
||||
|
||||
final def getSubsample: Double = $(subsample)
|
||||
|
||||
/**
|
||||
* subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
|
||||
*/
|
||||
val colSampleByTree = new DoubleParam(this, "colsample_bytree", "subsample ratio of columns" +
|
||||
" when constructing each tree.", (value: Double) => value <= 1 && value > 0)
|
||||
final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
|
||||
"columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)
|
||||
|
||||
final def getColsampleBytree: Double = $(colsampleBytree)
|
||||
|
||||
/**
|
||||
* subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
|
||||
*/
|
||||
val colSampleByLevel = new DoubleParam(this, "colsample_bylevel", "subsample ratio of columns" +
|
||||
" for each split, in each level.", (value: Double) => value <= 1 && value > 0)
|
||||
final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
|
||||
"columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)
|
||||
|
||||
final def getColsampleBylevel: Double = $(colsampleBylevel)
|
||||
|
||||
/**
|
||||
* L2 regularization term on weights, increase this value will make model more conservative.
|
||||
* [default=1]
|
||||
*/
|
||||
val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, increase this" +
|
||||
" value will make model more conservative.", (value: Double) => value >= 0)
|
||||
final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, " +
|
||||
"increase this value will make model more conservative.", (value: Double) => value >= 0)
|
||||
|
||||
final def getLambda: Double = $(lambda)
|
||||
|
||||
/**
|
||||
* L1 regularization term on weights, increase this value will make model more conservative.
|
||||
* [default=0]
|
||||
*/
|
||||
val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase this" +
|
||||
" value will make model more conservative.", (value: Double) => value >= 0)
|
||||
final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase " +
|
||||
"this value will make model more conservative.", (value: Double) => value >= 0)
|
||||
|
||||
final def getAlpha: Double = $(alpha)
|
||||
|
||||
/**
|
||||
* The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'}
|
||||
* [default='auto']
|
||||
*/
|
||||
val treeMethod = new Param[String](this, "tree_method",
|
||||
final val treeMethod = new Param[String](this, "treeMethod",
|
||||
"The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}",
|
||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||
|
||||
final def getTreeMethod: String = $(treeMethod)
|
||||
|
||||
/**
|
||||
* growth policy for fast histogram algorithm
|
||||
*/
|
||||
val growthPolicty = new Param[String](this, "grow_policy",
|
||||
final val growPolicy = new Param[String](this, "growPolicy",
|
||||
"growth policy for fast histogram algorithm",
|
||||
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
|
||||
|
||||
final def getGrowPolicy: String = $(growPolicy)
|
||||
|
||||
/**
|
||||
* maximum number of bins in histogram
|
||||
*/
|
||||
val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram",
|
||||
final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
|
||||
(value: Int) => value > 0)
|
||||
|
||||
final def getMaxBins: Int = $(maxBins)
|
||||
|
||||
/**
|
||||
* This is only used for approximate greedy algorithm.
|
||||
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
|
||||
* number of bins, this comes with theoretical guarantee with sketch accuracy.
|
||||
* [default=0.03] range: (0, 1)
|
||||
*/
|
||||
val sketchEps = new DoubleParam(this, "sketch_eps",
|
||||
final val sketchEps = new DoubleParam(this, "sketchEps",
|
||||
"This is only used for approximate greedy algorithm. This roughly translated into" +
|
||||
" O(1 / sketch_eps) number of bins. Compared to directly select number of bins, this comes" +
|
||||
" with theoretical guarantee with sketch accuracy.",
|
||||
(value: Double) => value < 1 && value > 0)
|
||||
|
||||
final def getSketchEps: Double = $(sketchEps)
|
||||
|
||||
/**
|
||||
* Control the balance of positive and negative weights, useful for unbalanced classes. A typical
|
||||
* value to consider: sum(negative cases) / sum(positive cases). [default=1]
|
||||
*/
|
||||
val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of positive" +
|
||||
" and negative weights, useful for unbalanced classes. A typical value to consider:" +
|
||||
final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " +
|
||||
"positive and negative weights, useful for unbalanced classes. A typical value to consider:" +
|
||||
" sum(negative cases) / sum(positive cases)")
|
||||
|
||||
final def getScalePosWeight: Double = $(scalePosWeight)
|
||||
|
||||
// Dart boosters
|
||||
|
||||
/**
|
||||
@ -167,72 +200,59 @@ trait BoosterParams extends Params {
|
||||
* Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
|
||||
* "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
|
||||
*/
|
||||
val sampleType = new Param[String](this, "sample_type", "type of sampling algorithm, options:" +
|
||||
" {'uniform', 'weighted'}",
|
||||
final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
|
||||
"options: {'uniform', 'weighted'}",
|
||||
(value: String) => BoosterParams.supportedSampleType.contains(value))
|
||||
|
||||
final def getSampleType: String = $(sampleType)
|
||||
|
||||
/**
|
||||
* Parameter of Dart booster.
|
||||
* type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
|
||||
*/
|
||||
val normalizeType = new Param[String](this, "normalize_type", "type of normalization" +
|
||||
final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
|
||||
" algorithm, options: {'tree', 'forest'}",
|
||||
(value: String) => BoosterParams.supportedNormalizeType.contains(value))
|
||||
|
||||
final def getNormalizeType: String = $(normalizeType)
|
||||
|
||||
/**
|
||||
* Parameter of Dart booster.
|
||||
* dropout rate. [default=0.0] range: [0.0, 1.0]
|
||||
*/
|
||||
val rateDrop = new DoubleParam(this, "rate_drop", "dropout rate", (value: Double) =>
|
||||
final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
|
||||
value >= 0 && value <= 1)
|
||||
|
||||
final def getRateDrop: Double = $(rateDrop)
|
||||
|
||||
/**
|
||||
* Parameter of Dart booster.
|
||||
* probability of skip dropout. If a dropout is skipped, new trees are added in the same manner
|
||||
* as gbtree. [default=0.0] range: [0.0, 1.0]
|
||||
*/
|
||||
val skipDrop = new DoubleParam(this, "skip_drop", "probability of skip dropout. If" +
|
||||
final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skip dropout. If" +
|
||||
" a dropout is skipped, new trees are added in the same manner as gbtree.",
|
||||
(value: Double) => value >= 0 && value <= 1)
|
||||
|
||||
final def getSkipDrop: Double = $(skipDrop)
|
||||
|
||||
// linear booster
|
||||
/**
|
||||
* Parameter of linear booster
|
||||
* L2 regularization term on bias, default 0(no L1 reg on bias because it is not important)
|
||||
*/
|
||||
val lambdaBias = new DoubleParam(this, "lambda_bias", "L2 regularization term on bias, default" +
|
||||
" 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
|
||||
final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
|
||||
"default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
|
||||
|
||||
setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
final def getLambdaBias: Double = $(lambdaBias)
|
||||
|
||||
setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
||||
growthPolicty -> "depthwise", maxBins -> 16,
|
||||
subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1,
|
||||
growPolicy -> "depthwise", maxBins -> 16,
|
||||
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
|
||||
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
||||
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
||||
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)
|
||||
|
||||
/**
|
||||
* Explains all params of this instance. See `explainParam()`.
|
||||
*/
|
||||
override def explainParams(): String = {
|
||||
// TODO: filter some parameters according to the booster type
|
||||
val boosterTypeStr = $(boosterType)
|
||||
val validParamList = {
|
||||
if (boosterTypeStr == "gblinear") {
|
||||
// gblinear
|
||||
params.filter(param => param.name == "lambda" ||
|
||||
param.name == "alpha" || param.name == "lambda_bias")
|
||||
} else if (boosterTypeStr != "dart") {
|
||||
// gbtree
|
||||
params.filter(param => param.name != "sample_type" &&
|
||||
param.name != "normalize_type" && param.name != "rate_drop" && param.name != "skip_drop")
|
||||
} else {
|
||||
// dart
|
||||
params.filter(_.name != "lambda_bias")
|
||||
}
|
||||
}
|
||||
explainParam(boosterType) + "\n" ++ validParamList.map(explainParam).mkString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] object BoosterParams {
|
||||
|
||||
@ -16,84 +16,104 @@

package ml.dmlc.xgboost4j.scala.spark.params

import com.google.common.base.CaseFormat
import ml.dmlc.xgboost4j.scala.spark.TrackerConf

import org.apache.spark.ml.param._
import scala.collection.mutable

trait GeneralParams extends Params {
|
||||
private[spark] trait GeneralParams extends Params {
|
||||
|
||||
/**
|
||||
* The number of rounds for boosting
|
||||
*/
|
||||
val round = new IntParam(this, "num_round", "The number of rounds for boosting",
|
||||
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def getNumRound: Int = $(numRound)
|
||||
|
||||
/**
|
||||
* number of workers used to train xgboost model. default: 1
|
||||
*/
|
||||
val nWorkers = new IntParam(this, "nworkers", "number of workers used to run xgboost",
|
||||
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def getNumWorkers: Int = $(numWorkers)
|
||||
|
||||
/**
|
||||
* number of threads used by per worker. default 1
|
||||
*/
|
||||
val numThreadPerTask = new IntParam(this, "nthread", "number of threads used by per worker",
|
||||
final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def getNthread: Int = $(nthread)
|
||||
|
||||
/**
|
||||
* whether to use external memory as cache. default: false
|
||||
*/
|
||||
val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external" +
|
||||
"memory as cache")
|
||||
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
|
||||
"whether to use external memory as cache")
|
||||
|
||||
final def getUseExternalMemory: Boolean = $(useExternalMemory)
|
||||
|
||||
/**
|
||||
* 0 means printing running messages, 1 means silent mode. default: 0
|
||||
*/
|
||||
val silent = new IntParam(this, "silent",
|
||||
final val silent = new IntParam(this, "silent",
|
||||
"0 means printing running messages, 1 means silent mode.",
|
||||
(value: Int) => value >= 0 && value <= 1)
|
||||
|
||||
final def getSilent: Int = $(silent)
|
||||
|
||||
/**
|
||||
* customized objective function provided by user. default: null
|
||||
*/
|
||||
val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " +
|
||||
final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
|
||||
"provided by user")
|
||||
|
||||
/**
|
||||
* customized evaluation function provided by user. default: null
|
||||
*/
|
||||
val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " +
|
||||
"provided by user")
|
||||
final val customEval = new CustomEvalParam(this, "customEval",
|
||||
"customized evaluation function provided by user")
|
||||
|
||||
/**
|
||||
* the value treated as missing. default: Float.NaN
|
||||
*/
|
||||
val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||
final val missing = new FloatParam(this, "missing", "the value treated as missing")
|
||||
|
||||
final def getMissing: Float = $(missing)
|
||||
|
||||
/**
|
||||
* the maximum time to wait for the job requesting new workers. default: 30 minutes
|
||||
*/
|
||||
val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" +
|
||||
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
|
||||
" value is set smaller than or equal to 0.")
|
||||
final val timeoutRequestWorkers = new LongParam(this, "timeoutRequestWorkers", "the maximum " +
|
||||
"time to request new Workers if numCores are insufficient. The timeout will be disabled " +
|
||||
"if this value is set smaller than or equal to 0.")
|
||||
|
||||
final def getTimeoutRequestWorkers: Long = $(timeoutRequestWorkers)
|
||||
|
||||
/**
|
||||
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
|
||||
*/
|
||||
val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
|
||||
"save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load " +
|
||||
"the checkpoint with highest version as the starting point for training. If " +
|
||||
final val checkpointPath = new Param[String](this, "checkpointPath", "the hdfs folder to load " +
|
||||
"and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will " +
|
||||
"load the checkpoint with highest version as the starting point for training. If " +
|
||||
"checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
|
||||
|
||||
final def getCheckpointPath: String = $(checkpointPath)
|
||||
|
||||
/**
|
||||
* Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that
|
||||
* the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must
|
||||
* also be set if the checkpoint interval is greater than 0.
|
||||
*/
|
||||
val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " +
|
||||
"interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " +
|
||||
"checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" +
|
||||
" interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1)
|
||||
final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval",
|
||||
"set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained " +
|
||||
"model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be " +
|
||||
"set if the checkpoint interval is greater than 0.",
|
||||
(interval: Int) => interval == -1 || interval >= 1)
|
||||
|
||||
final def getCheckpointInterval: Int = $(checkpointInterval)
|
||||
|
||||
/**
|
||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||
@ -122,15 +142,87 @@ trait GeneralParams extends Params {
|
||||
* Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
|
||||
* Ignored if the tracker implementation is "python".
|
||||
*/
|
||||
val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")
|
||||
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
|
||||
|
||||
/** Random seed for the C++ part of XGBoost and train/test splitting. */
|
||||
val seed = new LongParam(this, "seed", "random seed")
|
||||
final val seed = new LongParam(this, "seed", "random seed")
|
||||
|
||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||
final def getSeed: Long = $(seed)
|
||||
|
||||
setDefault(numRound -> 1, numWorkers -> 1, nthread -> 1,
|
||||
useExternalMemory -> false, silent -> 0,
|
||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
|
||||
checkpointPath -> "", checkpointInterval -> -1
|
||||
)
|
||||
}
|
||||
|
||||
trait HasBaseMarginCol extends Params {
|
||||
|
||||
/**
|
||||
* Param for initial prediction (aka base margin) column name.
|
||||
* @group param
|
||||
*/
|
||||
final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
|
||||
"Initial prediction (aka base margin) column name.")
|
||||
|
||||
/** @group getParam */
|
||||
final def getBaseMarginCol: String = $(baseMarginCol)
|
||||
}
|
||||
|
||||
trait HasGroupCol extends Params {
|
||||
|
||||
/**
|
||||
* Param for group column name.
|
||||
* @group param
|
||||
*/
|
||||
final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
|
||||
|
||||
/** @group getParam */
|
||||
final def getGroupCol: String = $(groupCol)
|
||||
|
||||
}
|
||||
|
||||
trait HasNumClass extends Params {
|
||||
|
||||
/**
|
||||
* number of classes
|
||||
*/
|
||||
final val numClass = new IntParam(this, "numClass", "number of classes")
|
||||
|
||||
/** @group getParam */
|
||||
final def getNumClass: Int = $(numClass)
|
||||
}
|
||||
|
||||
private[spark] trait ParamMapFuncs extends Params {
|
||||
|
||||
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
params.find(_.name == name) match {
|
||||
case None =>
|
||||
case Some(_: DoubleParam) =>
|
||||
set(name, paramValue.toString.toDouble)
|
||||
case Some(_: BooleanParam) =>
|
||||
set(name, paramValue.toString.toBoolean)
|
||||
case Some(_: IntParam) =>
|
||||
set(name, paramValue.toString.toInt)
|
||||
case Some(_: FloatParam) =>
|
||||
set(name, paramValue.toString.toFloat)
|
||||
case Some(_: Param[_]) =>
|
||||
set(name, paramValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def MLlib2XGBoostParams: Map[String, Any] = {
|
||||
val xgboostParams = new mutable.HashMap[String, Any]()
|
||||
for (param <- params) {
|
||||
if (isDefined(param)) {
|
||||
val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name)
|
||||
xgboostParams += name -> $(param)
|
||||
}
|
||||
}
|
||||
xgboostParams.toMap
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,76 +20,70 @@ import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
trait LearningTaskParams extends Params {
|
||||
|
||||
/**
|
||||
* number of tasks to learn
|
||||
*/
|
||||
val numClasses = new IntParam(this, "num_class", "number of classes")
|
||||
private[spark] trait LearningTaskParams extends Params {
|
||||
|
||||
/**
|
||||
* Specify the learning task and the corresponding learning objective.
|
||||
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
|
||||
* multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:linear
|
||||
*/
|
||||
val objective = new Param[String](this, "objective", "objective function used for training," +
|
||||
s" options: {${LearningTaskParams.supportedObjective.mkString(",")}",
|
||||
final val objective = new Param[String](this, "objective", "objective function used for " +
|
||||
s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
|
||||
(value: String) => LearningTaskParams.supportedObjective.contains(value))
|
||||
|
||||
final def getObjective: String = $(objective)
|
||||
|
||||
/**
|
||||
* the initial prediction score of all instances, global bias. default=0.5
|
||||
*/
|
||||
val baseScore = new DoubleParam(this, "base_score", "the initial prediction score of all" +
|
||||
final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" +
|
||||
" instances, global bias")
|
||||
|
||||
final def getBaseScore: Double = $(baseScore)
|
||||
|
||||
/**
|
||||
* evaluation metrics for validation data, a default metric will be assigned according to
|
||||
* objective(rmse for regression, and error for classification, mean average precision for
|
||||
* ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map,
|
||||
* gamma-deviance
|
||||
*/
|
||||
val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" +
|
||||
" data, a default metric will be assigned according to objective (rmse for regression, and" +
|
||||
" error for classification, mean average precision for ranking), options: " +
|
||||
s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
|
||||
"validation data, a default metric will be assigned according to objective " +
|
||||
"(rmse for regression, and error for classification, mean average precision for ranking), " +
|
||||
s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
|
||||
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
|
||||
|
||||
final def getEvalMetric: String = $(evalMetric)
|
||||
|
||||
/**
|
||||
* group data specify each group sizes for ranking task. To correspond to partition of
|
||||
* training data, it is nested.
|
||||
*/
|
||||
val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" +
|
||||
" for ranking task. To correspond to partition of training data, it is nested.")
|
||||
|
||||
/**
|
||||
* Initial prediction (aka base margin) column name.
|
||||
*/
|
||||
val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name")
|
||||
|
||||
/**
|
||||
* Instance weights column name.
|
||||
*/
|
||||
val weightCol = new Param[String](this, "weightCol", "weight column name")
|
||||
final val groupData = new GroupDataParam(this, "groupData", "group data specify each group " +
|
||||
"size for ranking task. To correspond to partition of training data, it is nested.")
|
||||
|
||||
/**
|
||||
* Fraction of training points to use for testing.
|
||||
*/
|
||||
val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
"fraction of training points to use for testing",
|
||||
ParamValidators.inRange(0, 1))
|
||||
|
||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||
|
||||
/**
|
||||
* If non-zero, the training will be stopped after a specified number
|
||||
* of consecutive increases in any evaluation metric.
|
||||
*/
|
||||
val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
|
||||
final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
|
||||
"number of rounds of decreasing eval metric to tolerate before " +
|
||||
"stopping the training",
|
||||
(value: Int) => value == 0 || value > 1)
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
|
||||
baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0,
|
||||
numEarlyStoppingRounds -> 0)
|
||||
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
|
||||
|
||||
setDefault(objective -> "reg:linear", baseScore -> 0.5, groupData -> null,
|
||||
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
|
||||
}
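A hedged sketch (illustrative values only) of how these learning-task params are expected to be supplied when constructing one of the refactored estimators:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val learningTaskParams = Map(
  "objective" -> "binary:logistic",        // must be one of supportedObjective
  "eval_metric" -> "logloss",              // must be one of supportedEvalMetrics
  "train_test_ratio" -> "0.8",             // fraction of training points used for testing, in [0, 1]
  "num_early_stopping_rounds" -> 5,        // 0 disables early stopping
  "num_round" -> 10,
  "num_workers" -> 2)
val estimator = new XGBoostClassifier(learningTaskParams)
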
|
||||
|
||||
private[spark] object LearningTaskParams {
|
||||
|
||||
@ -1,75 +0,0 @@
|
||||
0 1:985.574005058 2:320.223538037 3:0.621236086198
|
||||
0 1:1010.52917943 2:635.535543082 3:2.14984030531
|
||||
0 1:1012.91900422 2:132.387300057 3:0.488761066665
|
||||
0 1:990.829194034 2:135.102081162 3:0.747701610673
|
||||
0 1:1007.05103629 2:154.289183562 3:0.464118249201
|
||||
0 1:994.9573036 2:317.483732878 3:0.0313685555674
|
||||
0 1:987.8071541 2:731.349178363 3:0.244616944245
|
||||
1 1:10.0349544469 2:2.29750906143 3:36.4949974282
|
||||
0 1:9.92953881383 2:5.39134047297 3:120.041297548
|
||||
0 1:10.0909866713 2:9.06191026312 3:138.807825798
|
||||
1 1:10.2090970614 2:0.0784495944448 3:58.207703565
|
||||
0 1:9.85695905893 2:9.99500727713 3:56.8610243778
|
||||
1 1:10.0805758547 2:0.0410805760559 3:222.102302076
|
||||
0 1:10.1209914486 2:9.9729127088 3:171.888238763
|
||||
0 1:10.0331939798 2:0.853339303793 3:311.181328375
|
||||
0 1:9.93901762951 2:2.72757449146 3:78.4859514413
|
||||
0 1:10.0752365346 2:9.18695328235 3:49.8520256553
|
||||
1 1:10.0456548902 2:0.270936043122 3:123.462958597
|
||||
0 1:10.0568923673 2:0.82997113263 3:44.9391426001
|
||||
0 1:9.8214143472 2:0.277538931578 3:15.4217659578
|
||||
0 1:9.95258604431 2:8.69564346094 3:255.513470671
|
||||
0 1:9.91934976357 2:7.72809741413 3:82.171591817
|
||||
0 1:10.043239582 2:8.64168255553 3:38.9657919329
|
||||
1 1:10.0236147929 2:0.0496662263659 3:4.40889812286
|
||||
1 1:1001.85585324 2:3.75646886071 3:0.0179224994842
|
||||
0 1:1014.25578571 2:0.285765311201 3:0.510329864983
|
||||
1 1:1002.81422786 2:9.77676280375 3:0.433705951912
|
||||
1 1:998.072711553 2:2.82100686538 3:0.889829076909
|
||||
0 1:1003.77395036 2:2.55916592114 3:0.0359402151496
|
||||
1 1:10.0807877782 2:4.98513959013 3:47.5266363559
|
||||
0 1:10.0015013081 2:9.94302478763 3:78.3697486277
|
||||
1 1:10.0441936789 2:0.305091816635 3:56.8213984987
|
||||
0 1:9.94257106618 2:7.23909568913 3:442.463339039
|
||||
1 1:9.86479307916 2:6.41701315844 3:55.1365304834
|
||||
0 1:10.0428628516 2:9.98466447697 3:0.391632812588
|
||||
0 1:9.94445884566 2:9.99970945878 3:260.438436534
|
||||
1 1:9.84641392823 2:225.78051312 3:1.00525978847
|
||||
1 1:9.86907690608 2:26.8971083147 3:0.577959255991
|
||||
0 1:10.0177314626 2:0.110585342313 3:2.30545043031
|
||||
0 1:10.0688190907 2:412.023866234 3:1.22421542264
|
||||
0 1:10.1251769646 2:13.8212202925 3:0.129171734504
|
||||
0 1:10.0840758802 2:407.359097187 3:0.477000870705
|
||||
0 1:10.1007458705 2:987.183625145 3:0.149385677415
|
||||
0 1:9.86472656059 2:169.559640615 3:0.147221652519
|
||||
0 1:9.94207419238 2:507.290053755 3:0.41996207214
|
||||
0 1:9.9671005502 2:1.62610457716 3:0.408173666788
|
||||
0 1:1010.57126596 2:9.06673707562 3:0.672092284372
|
||||
0 1:1001.6718262 2:9.53203990055 3:4.7364050044
|
||||
0 1:995.777341384 2:4.43847316256 3:2.07229073634
|
||||
0 1:1002.95701386 2:5.51711016665 3:1.24294450546
|
||||
0 1:1016.0988238 2:0.626468941906 3:0.105627919134
|
||||
0 1:1013.67571419 2:0.042315529666 3:0.717619310322
|
||||
1 1:994.747747892 2:6.01989364024 3:0.772910130015
|
||||
1 1:991.654593872 2:7.35575736952 3:1.19822091548
|
||||
0 1:1008.47101732 2:8.28240754909 3:0.229582481359
|
||||
0 1:1000.81975227 2:1.52448354056 3:0.096441660362
|
||||
0 1:10.0900922344 2:322.656649307 3:57.8149073088
|
||||
1 1:10.0868337371 2:2.88652339174 3:54.8865514572
|
||||
0 1:10.0988984137 2:979.483832657 3:52.6809830901
|
||||
0 1:9.97678959238 2:665.770979738 3:481.069628909
|
||||
0 1:9.78554312773 2:257.309358658 3:47.7324475232
|
||||
0 1:10.0985967566 2:935.896512941 3:138.937052808
|
||||
0 1:10.0522252319 2:876.376299607 3:6.00373510669
|
||||
1 1:9.88065229501 2:9.99979825653 3:0.0674603696149
|
||||
0 1:10.0483244098 2:0.0653852316381 3:0.130679349938
|
||||
1 1:9.99685215607 2:1.76602542774 3:0.2551321159
|
||||
0 1:9.99750159428 2:1.01591534436 3:0.145445506504
|
||||
1 1:9.97380908941 2:0.940048645571 3:0.411805696316
|
||||
0 1:9.99977678382 2:6.91329929641 3:5.57858201258
|
||||
0 1:978.876096381 2:933.775364741 3:0.579170824236
|
||||
0 1:998.381016406 2:220.940470582 3:2.01491778565
|
||||
0 1:987.917644594 2:8.74667873567 3:0.364006099758
|
||||
0 1:1000.20994892 2:25.2945450565 3:3.5684398964
|
||||
0 1:1014.57141264 2:675.593540733 3:0.164174055535
|
||||
0 1:998.867283535 2:765.452750642 3:0.818425293238
|
||||
@ -1,10 +0,0 @@
|
||||
7
|
||||
7
|
||||
10
|
||||
5
|
||||
7
|
||||
10
|
||||
10
|
||||
7
|
||||
6
|
||||
6
|
||||
@ -1,74 +0,0 @@
|
||||
0 1:10.2143092481 2:273.576539531 3:137.111774354
|
||||
0 1:10.0366658918 2:842.469052609 3:2.32134375927
|
||||
0 1:10.1281202091 2:395.654057342 3:35.4184893063
|
||||
0 1:10.1443721289 2:960.058461049 3:272.887070637
|
||||
0 1:10.1353234784 2:535.51304462 3:2.15393842032
|
||||
1 1:10.0451640374 2:216.733858424 3:55.6533298016
|
||||
1 1:9.94254592171 2:44.5985537358 3:304.614176871
|
||||
0 1:10.1319257181 2:613.545504487 3:5.42391587912
|
||||
0 1:1020.63622468 2:997.476744201 3:0.509425590461
|
||||
0 1:986.304585519 2:822.669937965 3:0.605133561808
|
||||
1 1:1012.66863221 2:26.7185759069 3:0.0875458784828
|
||||
0 1:995.387656321 2:81.8540176995 3:0.691999430068
|
||||
0 1:1020.6587198 2:848.826964547 3:0.540159430526
|
||||
1 1:1003.81573853 2:379.84350931 3:0.0083682925194
|
||||
0 1:1021.60921516 2:641.376951467 3:1.12339054807
|
||||
0 1:1000.17585041 2:122.107138713 3:1.09906375372
|
||||
1 1:987.64802348 2:5.98448541152 3:0.124241987204
|
||||
1 1:9.94610136583 2:346.114985897 3:0.387708236565
|
||||
0 1:9.96812192337 2:313.278109696 3:0.00863026595671
|
||||
0 1:10.0181739194 2:36.7378924562 3:2.92179879835
|
||||
0 1:9.89000102695 2:164.273723971 3:0.685222591968
|
||||
0 1:10.1555212436 2:320.451459462 3:2.01341536261
|
||||
0 1:10.0085727613 2:999.767117646 3:0.462294934168
|
||||
1 1:9.93099658724 2:5.17478203909 3:0.213855205032
|
||||
0 1:10.0629454957 2:663.088181857 3:0.049022351462
|
||||
0 1:10.1109732417 2:734.904569784 3:1.6998450094
|
||||
0 1:1006.6015266 2:505.023453703 3:1.90870566777
|
||||
0 1:991.865769489 2:245.437343115 3:0.475109744256
|
||||
0 1:998.682734072 2:950.041057232 3:1.9256314201
|
||||
0 1:1005.02207209 2:2.9619314197 3:0.0517146822357
|
||||
0 1:1002.54526214 2:860.562681899 3:0.915687092848
|
||||
0 1:1000.38847359 2:808.416525088 3:0.209690673808
|
||||
1 1:992.557818382 2:373.889409453 3:0.107571728577
|
||||
0 1:1002.07722137 2:997.329626371 3:1.06504260496
|
||||
0 1:1000.40504333 2:949.832139189 3:0.539159980327
|
||||
0 1:10.1460179902 2:8.86082969819 3:135.953842715
|
||||
1 1:9.98529296553 2:2.87366448495 3:1.74249892194
|
||||
0 1:9.88942676744 2:9.4031821056 3:149.473066381
|
||||
1 1:10.0192953341 2:1.99685737576 3:1.79502473397
|
||||
0 1:10.0110654379 2:8.13112593726 3:87.7765628103
|
||||
0 1:997.148677047 2:733.936190093 3:1.49298494242
|
||||
0 1:1008.70465919 2:957.121652078 3:0.217414013634
|
||||
1 1:997.356154278 2:541.599587807 3:0.100855972216
|
||||
0 1:999.615897283 2:943.700501824 3:0.862874175879
|
||||
1 1:997.36859077 2:0.200859940848 3:0.13601892182
|
||||
0 1:10.0423255624 2:1.73855202168 3:0.956695338485
|
||||
1 1:9.88440755486 2:9.9994600678 3:0.305080529665
|
||||
0 1:10.0891026412 2:3.28031719474 3:0.364450973697
|
||||
0 1:9.90078644258 2:8.77839663617 3:0.456660574479
|
||||
1 1:9.79380029711 2:8.77220326156 3:0.527292005175
|
||||
0 1:9.93613887011 2:9.76270841268 3:1.40865693823
|
||||
0 1:10.0009239007 2:7.29056178263 3:0.498015866607
|
||||
0 1:9.96603319905 2:5.12498000925 3:0.517492532783
|
||||
0 1:10.0923827222 2:2.76652583955 3:1.56571226159
|
||||
1 1:10.0983782035 2:587.788120694 3:0.031756483687
|
||||
1 1:9.91397225464 2:994.527496819 3:3.72092164978
|
||||
0 1:10.1057472738 2:2.92894440088 3:0.683506438532
|
||||
0 1:10.1014053354 2:959.082038017 3:1.07039624129
|
||||
0 1:10.1433253044 2:322.515119317 3:0.51408278993
|
||||
1 1:9.82832510699 2:637.104433908 3:0.250272776427
|
||||
0 1:1000.49729075 2:2.75336888111 3:0.576634423274
|
||||
1 1:984.90338088 2:0.0295435794035 3:1.26273339929
|
||||
0 1:1001.53811442 2:4.64164410861 3:0.0293389959504
|
||||
1 1:995.875898395 2:5.08223403205 3:0.382330566779
|
||||
0 1:996.405937252 2:6.26395190757 3:0.453645816611
|
||||
0 1:10.0165140779 2:340.126072514 3:0.220794603312
|
||||
0 1:9.93482824816 2:951.672000448 3:0.124406293612
|
||||
0 1:10.1700278554 2:0.0140985961008 3:0.252452256311
|
||||
0 1:9.99825079542 2:950.382643896 3:0.875382402062
|
||||
0 1:9.87316410028 2:686.788257829 3:0.215886999825
|
||||
0 1:10.2893240654 2:89.3947931451 3:0.569578232133
|
||||
0 1:9.98689192703 2:0.430107535413 3:2.99869831728
|
||||
0 1:10.1365175107 2:972.279245093 3:0.0865099386744
|
||||
0 1:9.90744703306 2:50.810461183 3:3.00863325197
|
||||
@ -1,10 +0,0 @@
|
||||
8
|
||||
9
|
||||
9
|
||||
9
|
||||
5
|
||||
5
|
||||
9
|
||||
6
|
||||
5
|
||||
9
|
||||
@ -1,10 +0,0 @@
|
||||
7
|
||||
5
|
||||
9
|
||||
6
|
||||
6
|
||||
8
|
||||
7
|
||||
6
|
||||
5
|
||||
7
|
||||
@ -0,0 +1,66 @@
|
||||
0,10.0229017899,7.30178495562,0.118115020017,1
|
||||
0,9.93639621859,9.93102159291,0.0435030004396,1
|
||||
0,10.1301737265,0.00411765220572,2.4165878053,1
|
||||
1,9.87828587087,0.608588414992,0.111262590883,1
|
||||
0,10.1373430048,0.47764012225,0.991553052194,1
|
||||
0,10.0523814718,4.72152505167,0.672978832666,1
|
||||
0,10.0449715742,8.40373928536,0.384457573667,1
|
||||
1,996.398498791,941.976309154,0.230269231292,2
|
||||
0,1005.11269468,900.093680877,0.265031528873,2
|
||||
0,997.160349441,891.331101688,2.19362017313,2
|
||||
0,993.754139031,44.8000165317,1.03868009875,2
|
||||
1,994.831299184,241.959208453,0.667631827024,2
|
||||
0,995.948333283,7.94326917112,0.750490877118,3
|
||||
0,989.733981273,7.52077625436,0.0126335967282,3
|
||||
0,1003.54086516,6.48177510564,1.19441696788,3
|
||||
0,996.56177804,9.71959812613,1.33082465111,3
|
||||
0,1005.61382467,0.234339369309,1.17987797356,3
|
||||
1,980.215758708,6.85554542926,2.63965085259,3
|
||||
1,987.776408872,2.23354609991,0.841885278028,3
|
||||
0,1006.54260396,8.12142049834,2.26639471174,3
|
||||
0,1009.87927639,6.40028519044,0.775155669615,3
|
||||
0,9.95006244393,928.76896718,234.948458244,4
|
||||
1,10.0749152258,255.294574476,62.9728604166,4
|
||||
1,10.1916541988,312.682867085,92.299413677,4
|
||||
0,9.95646724484,742.263188416,53.3310473654,4
|
||||
0,9.86211293222,996.237023866,2.00760301168,4
|
||||
1,9.91801019468,303.971783709,50.3147230679,4
|
||||
0,996.983996934,9.52188222766,1.33588120981,5
|
||||
0,995.704388126,9.49260524915,0.908498516541,5
|
||||
0,987.86480767,0.0870786716821,0.108859297837,5
|
||||
0,1000.99561307,2.85272694575,0.171134518956,5
|
||||
0,1011.05508066,7.55336771768,1.04950084825,5
|
||||
1,985.52199365,0.763305780608,1.7402424375,5
|
||||
0,10.0430321467,813.185427181,4.97728254185,6
|
||||
0,10.0812334228,258.297288417,0.127477670549,6
|
||||
0,9.84210504292,887.205815261,0.991689193955,6
|
||||
1,9.94625332613,0.298622762132,0.147881353231,6
|
||||
0,9.97800659954,727.619819757,0.0718361141866,6
|
||||
1,9.8037938472,957.385549617,0.0618862028941,6
|
||||
0,10.0880634741,185.024638577,1.7028095095,6
|
||||
0,9.98630799154,109.10631473,0.681117359751,6
|
||||
0,9.91671416638,166.248076588,122.538291094,7
|
||||
0,10.1206910464,88.1539468531,141.189859069,7
|
||||
1,10.1767160518,1.02960996847,172.02256237,7
|
||||
0,9.93025147233,391.196641942,58.040338247,7
|
||||
0,9.84850936037,474.63346537,17.5627875397,7
|
||||
1,9.8162731343,61.9199554213,30.6740972851,7
|
||||
0,10.0403482984,987.50416929,73.0472906209,7
|
||||
1,997.019228359,133.294717663,0.0572254083186,8
|
||||
0,973.303999107,1.79080888849,0.100478717048,8
|
||||
0,1008.28808825,342.282350685,0.409806485495,8
|
||||
0,1014.55621524,0.680510407082,0.929530602495,8
|
||||
1,1012.74370325,823.105266455,0.0894693730585,8
|
||||
0,1003.63554038,727.334432075,0.58206275756,8
|
||||
0,10.1560432436,740.35938307,11.6823378533,9
|
||||
0,9.83949099701,512.828227154,138.206666681,9
|
||||
1,10.1837395682,179.287126088,185.479062365,9
|
||||
1,9.9761881495,12.1093388336,9.1264604171,9
|
||||
1,9.77402180766,318.561317743,80.6005221355,9
|
||||
0,1011.15705381,0.215825852155,1.34429667906,10
|
||||
0,1005.60353229,727.202346126,1.47146041005,10
|
||||
1,1013.93702961,58.7312725205,0.421041560754,10
|
||||
0,1004.86813074,757.693204258,0.566055205344,10
|
||||
0,999.996324692,813.12386828,0.864428279513,10
|
||||
0,996.55255931,918.760056995,0.43365051974,10
|
||||
1,1004.1394132,464.371823646,0.312492288321,10
|
||||
|
149  jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv (new file)
@ -0,0 +1,149 @@
|
||||
0,985.574005058,320.223538037,0.621236086198,1
|
||||
0,1010.52917943,635.535543082,2.14984030531,1
|
||||
0,1012.91900422,132.387300057,0.488761066665,1
|
||||
0,990.829194034,135.102081162,0.747701610673,1
|
||||
0,1007.05103629,154.289183562,0.464118249201,1
|
||||
0,994.9573036,317.483732878,0.0313685555674,1
|
||||
0,987.8071541,731.349178363,0.244616944245,1
|
||||
1,10.0349544469,2.29750906143,36.4949974282,2
|
||||
0,9.92953881383,5.39134047297,120.041297548,2
|
||||
0,10.0909866713,9.06191026312,138.807825798,2
|
||||
1,10.2090970614,0.0784495944448,58.207703565,2
|
||||
0,9.85695905893,9.99500727713,56.8610243778,2
|
||||
1,10.0805758547,0.0410805760559,222.102302076,2
|
||||
0,10.1209914486,9.9729127088,171.888238763,2
|
||||
0,10.0331939798,0.853339303793,311.181328375,3
|
||||
0,9.93901762951,2.72757449146,78.4859514413,3
|
||||
0,10.0752365346,9.18695328235,49.8520256553,3
|
||||
1,10.0456548902,0.270936043122,123.462958597,3
|
||||
0,10.0568923673,0.82997113263,44.9391426001,3
|
||||
0,9.8214143472,0.277538931578,15.4217659578,3
|
||||
0,9.95258604431,8.69564346094,255.513470671,3
|
||||
0,9.91934976357,7.72809741413,82.171591817,3
|
||||
0,10.043239582,8.64168255553,38.9657919329,3
|
||||
1,10.0236147929,0.0496662263659,4.40889812286,3
|
||||
1,1001.85585324,3.75646886071,0.0179224994842,4
|
||||
0,1014.25578571,0.285765311201,0.510329864983,4
|
||||
1,1002.81422786,9.77676280375,0.433705951912,4
|
||||
1,998.072711553,2.82100686538,0.889829076909,4
|
||||
0,1003.77395036,2.55916592114,0.0359402151496,4
|
||||
1,10.0807877782,4.98513959013,47.5266363559,5
|
||||
0,10.0015013081,9.94302478763,78.3697486277,5
|
||||
1,10.0441936789,0.305091816635,56.8213984987,5
|
||||
0,9.94257106618,7.23909568913,442.463339039,5
|
||||
1,9.86479307916,6.41701315844,55.1365304834,5
|
||||
0,10.0428628516,9.98466447697,0.391632812588,5
|
||||
0,9.94445884566,9.99970945878,260.438436534,5
|
||||
1,9.84641392823,225.78051312,1.00525978847,6
|
||||
1,9.86907690608,26.8971083147,0.577959255991,6
|
||||
0,10.0177314626,0.110585342313,2.30545043031,6
|
||||
0,10.0688190907,412.023866234,1.22421542264,6
|
||||
0,10.1251769646,13.8212202925,0.129171734504,6
|
||||
0,10.0840758802,407.359097187,0.477000870705,6
|
||||
0,10.1007458705,987.183625145,0.149385677415,6
|
||||
0,9.86472656059,169.559640615,0.147221652519,6
|
||||
0,9.94207419238,507.290053755,0.41996207214,6
|
||||
0,9.9671005502,1.62610457716,0.408173666788,6
|
||||
0,1010.57126596,9.06673707562,0.672092284372,7
|
||||
0,1001.6718262,9.53203990055,4.7364050044,7
|
||||
0,995.777341384,4.43847316256,2.07229073634,7
|
||||
0,1002.95701386,5.51711016665,1.24294450546,7
|
||||
0,1016.0988238,0.626468941906,0.105627919134,7
|
||||
0,1013.67571419,0.042315529666,0.717619310322,7
|
||||
1,994.747747892,6.01989364024,0.772910130015,7
|
||||
1,991.654593872,7.35575736952,1.19822091548,7
|
||||
0,1008.47101732,8.28240754909,0.229582481359,7
|
||||
0,1000.81975227,1.52448354056,0.096441660362,7
|
||||
0,10.0900922344,322.656649307,57.8149073088,8
|
||||
1,10.0868337371,2.88652339174,54.8865514572,8
|
||||
0,10.0988984137,979.483832657,52.6809830901,8
|
||||
0,9.97678959238,665.770979738,481.069628909,8
|
||||
0,9.78554312773,257.309358658,47.7324475232,8
|
||||
0,10.0985967566,935.896512941,138.937052808,8
|
||||
0,10.0522252319,876.376299607,6.00373510669,8
|
||||
1,9.88065229501,9.99979825653,0.0674603696149,9
|
||||
0,10.0483244098,0.0653852316381,0.130679349938,9
|
||||
1,9.99685215607,1.76602542774,0.2551321159,9
|
||||
0,9.99750159428,1.01591534436,0.145445506504,9
|
||||
1,9.97380908941,0.940048645571,0.411805696316,9
|
||||
0,9.99977678382,6.91329929641,5.57858201258,9
|
||||
0,978.876096381,933.775364741,0.579170824236,10
|
||||
0,998.381016406,220.940470582,2.01491778565,10
|
||||
0,987.917644594,8.74667873567,0.364006099758,10
|
||||
0,1000.20994892,25.2945450565,3.5684398964,10
|
||||
0,1014.57141264,675.593540733,0.164174055535,10
|
||||
0,998.867283535,765.452750642,0.818425293238,10
|
||||
0,10.2143092481,273.576539531,137.111774354,11
|
||||
0,10.0366658918,842.469052609,2.32134375927,11
|
||||
0,10.1281202091,395.654057342,35.4184893063,11
|
||||
0,10.1443721289,960.058461049,272.887070637,11
|
||||
0,10.1353234784,535.51304462,2.15393842032,11
|
||||
1,10.0451640374,216.733858424,55.6533298016,11
|
||||
1,9.94254592171,44.5985537358,304.614176871,11
|
||||
0,10.1319257181,613.545504487,5.42391587912,11
|
||||
0,1020.63622468,997.476744201,0.509425590461,12
|
||||
0,986.304585519,822.669937965,0.605133561808,12
|
||||
1,1012.66863221,26.7185759069,0.0875458784828,12
|
||||
0,995.387656321,81.8540176995,0.691999430068,12
|
||||
0,1020.6587198,848.826964547,0.540159430526,12
|
||||
1,1003.81573853,379.84350931,0.0083682925194,12
|
||||
0,1021.60921516,641.376951467,1.12339054807,12
|
||||
0,1000.17585041,122.107138713,1.09906375372,12
|
||||
1,987.64802348,5.98448541152,0.124241987204,12
|
||||
1,9.94610136583,346.114985897,0.387708236565,13
|
||||
0,9.96812192337,313.278109696,0.00863026595671,13
|
||||
0,10.0181739194,36.7378924562,2.92179879835,13
|
||||
0,9.89000102695,164.273723971,0.685222591968,13
|
||||
0,10.1555212436,320.451459462,2.01341536261,13
|
||||
0,10.0085727613,999.767117646,0.462294934168,13
|
||||
1,9.93099658724,5.17478203909,0.213855205032,13
|
||||
0,10.0629454957,663.088181857,0.049022351462,13
|
||||
0,10.1109732417,734.904569784,1.6998450094,13
|
||||
0,1006.6015266,505.023453703,1.90870566777,14
|
||||
0,991.865769489,245.437343115,0.475109744256,14
|
||||
0,998.682734072,950.041057232,1.9256314201,14
|
||||
0,1005.02207209,2.9619314197,0.0517146822357,14
|
||||
0,1002.54526214,860.562681899,0.915687092848,14
|
||||
0,1000.38847359,808.416525088,0.209690673808,14
|
||||
1,992.557818382,373.889409453,0.107571728577,14
|
||||
0,1002.07722137,997.329626371,1.06504260496,14
|
||||
0,1000.40504333,949.832139189,0.539159980327,14
|
||||
0,10.1460179902,8.86082969819,135.953842715,15
|
||||
1,9.98529296553,2.87366448495,1.74249892194,15
|
||||
0,9.88942676744,9.4031821056,149.473066381,15
|
||||
1,10.0192953341,1.99685737576,1.79502473397,15
|
||||
0,10.0110654379,8.13112593726,87.7765628103,15
|
||||
0,997.148677047,733.936190093,1.49298494242,16
|
||||
0,1008.70465919,957.121652078,0.217414013634,16
|
||||
1,997.356154278,541.599587807,0.100855972216,16
|
||||
0,999.615897283,943.700501824,0.862874175879,16
|
||||
1,997.36859077,0.200859940848,0.13601892182,16
|
||||
0,10.0423255624,1.73855202168,0.956695338485,17
|
||||
1,9.88440755486,9.9994600678,0.305080529665,17
|
||||
0,10.0891026412,3.28031719474,0.364450973697,17
|
||||
0,9.90078644258,8.77839663617,0.456660574479,17
|
||||
1,9.79380029711,8.77220326156,0.527292005175,17
|
||||
0,9.93613887011,9.76270841268,1.40865693823,17
|
||||
0,10.0009239007,7.29056178263,0.498015866607,17
|
||||
0,9.96603319905,5.12498000925,0.517492532783,17
|
||||
0,10.0923827222,2.76652583955,1.56571226159,17
|
||||
1,10.0983782035,587.788120694,0.031756483687,18
|
||||
1,9.91397225464,994.527496819,3.72092164978,18
|
||||
0,10.1057472738,2.92894440088,0.683506438532,18
|
||||
0,10.1014053354,959.082038017,1.07039624129,18
|
||||
0,10.1433253044,322.515119317,0.51408278993,18
|
||||
1,9.82832510699,637.104433908,0.250272776427,18
|
||||
0,1000.49729075,2.75336888111,0.576634423274,19
|
||||
1,984.90338088,0.0295435794035,1.26273339929,19
|
||||
0,1001.53811442,4.64164410861,0.0293389959504,19
|
||||
1,995.875898395,5.08223403205,0.382330566779,19
|
||||
0,996.405937252,6.26395190757,0.453645816611,19
|
||||
0,10.0165140779,340.126072514,0.220794603312,20
|
||||
0,9.93482824816,951.672000448,0.124406293612,20
|
||||
0,10.1700278554,0.0140985961008,0.252452256311,20
|
||||
0,9.99825079542,950.382643896,0.875382402062,20
|
||||
0,9.87316410028,686.788257829,0.215886999825,20
|
||||
0,10.2893240654,89.3947931451,0.569578232133,20
|
||||
0,9.98689192703,0.430107535413,2.99869831728,20
|
||||
0,10.1365175107,972.279245093,0.0865099386744,20
|
||||
0,9.90744703306,50.810461183,3.00863325197,20
|
||||
|
@ -21,37 +21,27 @@ import java.nio.file.Files
|
||||
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
|
||||
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
var sc: SparkContext = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
val conf: SparkConf = new SparkConf()
|
||||
.setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite")
|
||||
sc = new SparkContext(conf)
|
||||
}
|
||||
class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
private lazy val (model4, model8) = {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, nWorkers = sc.defaultParallelism),
|
||||
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, nWorkers = sc.defaultParallelism))
|
||||
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
}
|
||||
|
||||
test("test update/load models") {
|
||||
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
manager.updateCheckpoint(model4)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
|
||||
|
||||
manager.updateCheckpoint(model8)
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
@ -61,7 +51,7 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
test("test cleanUpHigherVersions") {
|
||||
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
manager.updateCheckpoint(model8)
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.cleanUpHigherVersions(round = 8)
|
||||
assert(new File(s"$tmpPath/8.model").exists())
|
||||
|
||||
@ -74,7 +64,8 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
|
||||
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
|
||||
manager.updateCheckpoint(model4)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
|
||||
}
|
||||
|
||||
}
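A hedged illustration (paths and values assumed) of how the checkpointing params exercised by this suite would be set on the refactored estimator, matching the checkpointPath / checkpointInterval defaults shown earlier in this diff:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val checkpointedClf = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic", "num_round" -> 100, "num_workers" -> 4,
  "checkpoint_path" -> "/tmp/xgb-checkpoints",   // hypothetical path, illustration only
  "checkpoint_interval" -> 10))                  // persist the booster every 10 rounds
// if training is interrupted, rerunning fit is expected to resume from the latest saved round
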
|
||||
|
||||
@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||
|
||||
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
|
||||
protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
|
||||
|
||||
@transient private var currentSession: SparkSession = _
|
||||
@ -62,4 +64,30 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
file.delete()
|
||||
}
|
||||
}
|
||||
|
||||
protected def buildDataFrame(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
import DataUtils._
|
||||
val it = labeledPoints.iterator.zipWithIndex
|
||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||
(id, labeledPoint.label, labeledPoint.features)
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features")
|
||||
}
|
||||
|
||||
protected def buildDataFrameWithGroup(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
import DataUtils._
|
||||
val it = labeledPoints.iterator.zipWithIndex
|
||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features", "group")
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,167 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileNotFoundException}
|
||||
import java.util.Arrays
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
|
||||
import scala.util.Random
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
|
||||
class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
private var tempDir: File = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
|
||||
tempDir = new File(System.getProperty("java.io.tmpdir"), this.getClass.getName)
|
||||
if (tempDir.exists) {
|
||||
tempDir.delete
|
||||
}
|
||||
tempDir.mkdirs
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
JavaUtils.deleteRecursively(tempDir)
|
||||
super.afterAll()
|
||||
}
|
||||
|
||||
private def delete(f: File) {
|
||||
if (f.exists) {
|
||||
if (f.isDirectory) {
|
||||
for (c <- f.listFiles) {
|
||||
delete(c)
|
||||
}
|
||||
}
|
||||
if (!f.delete) {
|
||||
throw new FileNotFoundException("Failed to delete file: " + f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
|
||||
val eval = new EvalError()
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
|
||||
val xgbc = new XGBoostClassifier(paramMap)
|
||||
val xgbcPath = new File(tempDir, "xgbc").getPath
|
||||
xgbc.write.overwrite().save(xgbcPath)
|
||||
val xgbc2 = XGBoostClassifier.load(xgbcPath)
|
||||
val paramMap2 = xgbc2.MLlib2XGBoostParams
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v.toString == paramMap2(k).toString)
|
||||
}
|
||||
|
||||
val model = xgbc.fit(trainingDF)
|
||||
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults < 0.1)
|
||||
val xgbcModelPath = new File(tempDir, "xgbcModel").getPath
|
||||
model.write.overwrite.save(xgbcModelPath)
|
||||
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
|
||||
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
|
||||
|
||||
assert(model.getEta === model2.getEta)
|
||||
assert(model.getNumRound === model2.getNumRound)
|
||||
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
|
||||
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults === evalResults2)
|
||||
}
|
||||
|
||||
test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
|
||||
val eval = new EvalError()
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
val testDM = new DMatrix(Regression.test.iterator)
|
||||
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear", "num_round" -> "10", "num_workers" -> numWorkers)
|
||||
val xgbr = new XGBoostRegressor(paramMap)
|
||||
val xgbrPath = new File(tempDir, "xgbr").getPath
|
||||
xgbr.write.overwrite().save(xgbrPath)
|
||||
val xgbr2 = XGBoostRegressor.load(xgbrPath)
|
||||
val paramMap2 = xgbr2.MLlib2XGBoostParams
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v.toString == paramMap2(k).toString)
|
||||
}
|
||||
|
||||
val model = xgbr.fit(trainingDF)
|
||||
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults < 0.1)
|
||||
val xgbrModelPath = new File(tempDir, "xgbrModel").getPath
|
||||
model.write.overwrite.save(xgbrModelPath)
|
||||
val model2 = XGBoostRegressionModel.load(xgbrModelPath)
|
||||
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
|
||||
|
||||
assert(model.getEta === model2.getEta)
|
||||
assert(model.getNumRound === model2.getNumRound)
|
||||
assert(model.getPredictionCol === model2.getPredictionCol)
|
||||
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults === evalResults2)
|
||||
}
|
||||
|
||||
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
|
||||
|
||||
val r = new Random(0)
|
||||
// maybe move to shared context, but requires session to import implicits
|
||||
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
|
||||
toDF("feature", "label")
|
||||
|
||||
val assembler = new VectorAssembler()
|
||||
.setInputCols(df.columns.filter(!_.contains("label")))
|
||||
.setOutputCol("features")
|
||||
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
|
||||
// Construct MLlib pipeline, save and load
|
||||
val pipeline = new Pipeline().setStages(Array(assembler, xgb))
|
||||
val pipePath = new File(tempDir, "pipeline").getPath
|
||||
pipeline.write.overwrite().save(pipePath)
|
||||
val pipeline2 = Pipeline.read.load(pipePath)
|
||||
val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
|
||||
val paramMap2 = xgb2.MLlib2XGBoostParams
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v.toString == paramMap2(k).toString)
|
||||
}
|
||||
|
||||
// Model training, save and load
|
||||
val pipeModel = pipeline.fit(df)
|
||||
val pipeModelPath = new File(tempDir, "pipelineModel").getPath
|
||||
pipeModel.write.overwrite.save(pipeModelPath)
|
||||
val pipeModel2 = PipelineModel.load(pipeModelPath)
|
||||
|
||||
val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
|
||||
val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]
|
||||
|
||||
assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))
|
||||
|
||||
assert(xgbModel.getEta === xgbModel2.getEta)
|
||||
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
|
||||
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,8 +16,8 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.io.Source
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
trait TrainTestData {
|
||||
@ -48,6 +48,17 @@ trait TrainTestData {
|
||||
XGBLabeledPoint(label, null, values)
|
||||
}.toList
|
||||
}
|
||||
|
||||
protected def getLabeledPointsWithGroup(resource: String): Seq[XGBLabeledPoint] = {
|
||||
getResourceLines(resource).map { line =>
|
||||
val original = line.split(",")
|
||||
val length = original.length
|
||||
val label = original.head.toFloat
|
||||
val group = original.last.toInt
|
||||
val values = original.slice(1, length - 1).map(_.toFloat)
|
||||
XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
|
||||
}.toList
|
||||
}
|
||||
}
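For clarity, a short walk-through (illustration only) of how getLabeledPointsWithGroup above parses one line of the new rank.train.csv, e.g. "0,985.574005058,320.223538037,0.621236086198,1":

// label comes first, the group id last, and everything in between is a feature value
val original = "0,985.574005058,320.223538037,0.621236086198,1".split(",")
val label = original.head.toFloat                                   // 0.0f
val group = original.last.toInt                                     // 1
val values = original.slice(1, original.length - 1).map(_.toFloat)  // three feature values
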
|
||||
|
||||
object Classification extends TrainTestData {
|
||||
@ -80,11 +91,8 @@ object Regression extends TrainTestData {
|
||||
}
|
||||
|
||||
object Ranking extends TrainTestData {
|
||||
val train0: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-0.txt.train", zeroBased = false)
|
||||
val train1: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-1.txt.train", zeroBased = false)
|
||||
val trainGroup0: Seq[Int] = getGroups("/rank-demo-0.txt.train.group")
|
||||
val trainGroup1: Seq[Int] = getGroups("/rank-demo-1.txt.train.group")
|
||||
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo.txt.test", zeroBased = false)
|
||||
val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
|
||||
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
|
||||
|
||||
private def getGroups(resource: String): Seq[Int] = {
|
||||
getResourceLines(resource).map(_.toInt).toList
|
||||
|
||||
@ -0,0 +1,207 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
|
||||
test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
|
||||
val trainingDM = new DMatrix(Classification.train.iterator)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
val round = 5
|
||||
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
|
||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||
val prediction1 = model1.predict(testDM)
|
||||
|
||||
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
|
||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||
|
||||
val prediction2 = model2.transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
|
||||
|
||||
assert(testDF.count() === prediction2.size)
|
||||
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
|
||||
for (i <- prediction1.indices) {
|
||||
assert(prediction1(i).length === prediction2(i).values.length - 1)
|
||||
for (j <- prediction1(i).indices) {
|
||||
assert(prediction1(i)(j) === prediction2(i)(j + 1))
|
||||
}
|
||||
}
|
||||
|
||||
val prediction3 = model1.predict(testDM, outPutMargin = true)
|
||||
val prediction4 = model2.transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
|
||||
|
||||
assert(testDF.count() === prediction4.size)
|
||||
for (i <- prediction3.indices) {
|
||||
assert(prediction3(i).length === prediction4(i).values.length)
|
||||
for (j <- prediction3(i).indices) {
|
||||
assert(prediction3(i)(j) === prediction4(i)(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
val round = 5
|
||||
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> round,
|
||||
"num_workers" -> numWorkers)
|
||||
|
||||
// Set params in XGBoost way
|
||||
val model1 = new XGBoostClassifier(paramMap).fit(trainingDF)
|
||||
// Set params in MLlib way
|
||||
val model2 = new XGBoostClassifier()
|
||||
.setEta(1)
|
||||
.setMaxDepth(6)
|
||||
.setSilent(1)
|
||||
.setObjective("binary:logistic")
|
||||
.setNumRound(round)
|
||||
.setNumWorkers(numWorkers)
|
||||
.fit(trainingDF)
|
||||
|
||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
assert(p1 === p2)
|
||||
}
|
||||
}
|
||||
|
||||
test("test schema of XGBoostClassificationModel") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
|
||||
val model = new XGBoostClassifier(paramMap).fit(trainingDF)
|
||||
|
||||
model.setRawPredictionCol("raw_prediction")
|
||||
.setProbabilityCol("probability_prediction")
|
||||
.setPredictionCol("final_prediction")
|
||||
var predictionDF = model.transform(testDF)
|
||||
assert(predictionDF.columns.contains("id"))
|
||||
assert(predictionDF.columns.contains("features"))
|
||||
assert(predictionDF.columns.contains("label"))
|
||||
assert(predictionDF.columns.contains("raw_prediction"))
|
||||
assert(predictionDF.columns.contains("probability_prediction"))
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
model.setRawPredictionCol("").setPredictionCol("final_prediction")
|
||||
predictionDF = model.transform(testDF)
|
||||
assert(predictionDF.columns.contains("raw_prediction") === false)
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
model.setRawPredictionCol("raw_prediction").setPredictionCol("")
|
||||
predictionDF = model.transform(testDF)
|
||||
assert(predictionDF.columns.contains("raw_prediction"))
|
||||
assert(predictionDF.columns.contains("final_prediction") === false)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
|
||||
// from xgboost params to spark params
|
||||
val xgb = new XGBoostClassifier(xgbParamMap)
|
||||
assert(xgb.getEta === 1.0)
|
||||
assert(xgb.getObjective === "binary:logistic")
|
||||
// from spark to xgboost params
|
||||
val xgbCopy = xgb.copy(ParamMap.empty)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
|
||||
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
|
||||
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
|
||||
}
|
||||
|
||||
test("multi class classification") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
|
||||
"num_workers" -> numWorkers)
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
assert(model.getEta == 0.1)
|
||||
assert(model.getMaxDepth == 6)
|
||||
assert(model.numClasses == 6)
|
||||
}
|
||||
|
||||
test("use base margin") {
|
||||
val training1 = buildDataFrame(Classification.train)
|
||||
val training2 = training1.withColumn("margin", functions.rand())
|
||||
val test = buildDataFrame(Classification.test)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "test_train_split" -> "0.5",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model1 = xgb.fit(training1)
|
||||
val model2 = xgb.setBaseMarginCol("margin").fit(training2)
|
||||
val prediction1 = model1.transform(test).select(model1.getProbabilityCol)
|
||||
.collect().map(row => row.getAs[Vector](0))
|
||||
val prediction2 = model2.transform(test).select(model2.getProbabilityCol)
|
||||
.collect().map(row => row.getAs[Vector](0))
|
||||
var count = 0
|
||||
for ((r1, r2) <- prediction1.zip(prediction2)) {
|
||||
if (!r1.equals(r2)) count = count + 1
|
||||
}
|
||||
assert(count != 0)
|
||||
}
|
||||
|
||||
test("training summary") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
|
||||
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(training)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
}
|
||||
}
|
||||
@ -17,36 +17,34 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
|
||||
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
|
||||
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
|
||||
|
||||
test("nthread configuration must be no larger than spark.task.cpus") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
|
||||
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
|
||||
intercept[IllegalArgumentException] {
|
||||
XGBoost.trainWithRDD(sc.parallelize(List()), paramMap, 5, numWorkers)
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
|
||||
}
|
||||
}
|
||||
|
||||
test("kryoSerializer test") {
|
||||
import DataUtils._
|
||||
// TODO write an isolated test for Booster.
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator, null)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator, null)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val eval = new EvalError()
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,265 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||
|
||||
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
|
||||
private def buildDataFrame(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
import DataUtils._
|
||||
val it = labeledPoints.iterator.zipWithIndex
|
||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||
(id, labeledPoint.label, labeledPoint.features)
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features")
|
||||
}
|
||||
|
||||
test("test consistency and order preservation of dataframe-based model") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingItr = Classification.train.iterator
|
||||
val testItr = Classification.test.iterator
|
||||
val round = 5
|
||||
val trainDMatrix = new DMatrix(trainingItr)
|
||||
val testDMatrix = new DMatrix(testItr)
|
||||
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
|
||||
val predResultFromSeq = xgboostModel.predict(testDMatrix)
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = round, nWorkers = numWorkers)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
// the vector length in probabilities column is 2 since we have to fit to the evaluator in
|
||||
// Spark
|
||||
for (i <- predResultFromSeq.indices) {
|
||||
assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1)
|
||||
for (j <- predResultFromSeq(i).indices) {
|
||||
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("test transformLeaf") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers)
val testDF = buildDataFrame(Classification.test)
xgBoostModelWithDF.transformLeaf(testDF).show()
}

test("test schema of XGBoostRegressionModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
val trainingDF = buildDataFrame(Regression.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.setPredictionCol("final_prediction")
val testDF = buildDataFrame(Regression.test)
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("final_prediction"))
predictionDF.show()
}

test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
"raw_prediction").setPredictionCol("final_prediction")
val testDF = buildDataFrame(Classification.test)
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
setPredictionCol("final_prediction")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
}

test("xgboost and spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
// from xgboost params to spark params
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
assert(xgbEstimator.get(xgbEstimator.eta).get === 1.0)
assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
// from spark to xgboost params
val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
}

test("eval_metric is configured correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
val sparkParamMap = ParamMap.empty
val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
}

ignore("fast histogram algorithm parameters are exposed correctly") {
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val testItr = Classification.test.iterator
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 10, nWorkers = math.min(2, numWorkers))
val error = new EvalError
val testSetDMatrix = new DMatrix(testItr)
assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}

test("multi_class classification test") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val trainingDF = buildDataFrame(MultiClassification.train)
XGBoost.trainWithDataFrame(trainingDF.toDF(), paramMap, round = 5, nWorkers = numWorkers)
}

test("test DF use nested groupData") {
val trainingDF = buildDataFrame(Ranking.train0, 1)
.union(buildDataFrame(Ranking.train1, 1))
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)

val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = 2)
val testDF = buildDataFrame(Ranking.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
assert(testDF.count() === predResultsFromDF.size)
}

test("params of estimator and produced model are coordinated correctly") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val trainingDF = buildDataFrame(MultiClassification.train)
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
assert(model.get[Double](model.eta).get == 0.1)
assert(model.get[Int](model.maxDepth).get == 6)
assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
}

test("test use base margin") {
import DataUtils._
val trainingDf = buildDataFrame(Classification.train)
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
val testRDD = sc.parallelize(Classification.test.map(_.features))
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "baseMarginCol" -> "margin",
"testTrainSplit" -> 0.5)

def trainPredict(df: Dataset[_]): Array[Float] = {
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
.predict(testRDD)
.map { case Array(p) => p }
.collect()
}

val pred = trainPredict(trainingDf)
val predWithMargin = trainPredict(trainingDfWithMargin)
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
}

test("test use weight") {
import DataUtils._
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "weightCol" -> "weight")

val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))

val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
.setPredictionCol("final_prediction")
.setExternalMemory(true)
val testRDD = sc.parallelize(Regression.test.map(_.features))
val predictions = model.predict(testRDD).collect().flatten

// The predictions heavily relies on the first training instance, and thus are very close.
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
}

test("training summary") {
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap

val trainingDf = buildDataFrame(Classification.train)
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers)

assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}

test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
val trainingDf = buildDataFrame(Classification.train)

forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
}
}
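The deprecated XGBoost.trainWithDataFrame / XGBoost.trainWithRDD entry points exercised above are replaced throughout this commit by the estimator-style API. As a rough orientation only, here is a minimal sketch of the equivalent flow with the new XGBoostClassifier; the parameter names are taken from the new tests below, while buildDataFrame, numWorkers and the Classification fixture are assumed from the shared test utilities:

// Sketch, not part of the diff: estimator-style equivalent of a trainWithDataFrame call.
val sketchParams = Map(
  "eta" -> "1", "max_depth" -> "6", "silent" -> "1",
  "objective" -> "binary:logistic",
  "num_round" -> 5, "num_workers" -> numWorkers)
val sketchModel = new XGBoostClassifier(sketchParams)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .fit(buildDataFrame(Classification.train))
// xgboost-spark appends the prediction columns to the transformed DataFrame.
sketchModel.transform(buildDataFrame(Classification.test)).show()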
@ -18,19 +18,18 @@ package ml.dmlc.xgboost4j.scala.spark

import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque

import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.FunSuite
import scala.util.Random

class XGBoostGeneralSuite extends FunSuite with PerTest {

test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
@ -87,283 +86,153 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
}

test("training with external memory cache") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"use_external_memory" -> true)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}


test("training with Scala-implemented Rabit tracker") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}


ignore("test with fast histo depthwise") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "error")
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
// TODO: histogram algorithm seems to be very very sensitive to worker number
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}

ignore("test with fast histo lossguide") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
"max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5,
"num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}

ignore("test with fast histo lossguide with max bin") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}

ignore("test with fast histo depthwidth with max depth") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}

ignore("test with fast histo depthwidth with max depth and max bin") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}

test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[MLLabeledPoint] = {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
val numCols = 5

val labeledPoints = (0 until numRows).map { _ =>
val label = Random.nextDouble()
val data = (0 until numRows).map { x =>
val label = Random.nextInt(2)
val values = Array.tabulate[Double](numCols) { c =>
if (c == numCols - 1) -0.1 else Random.nextDouble()
if (c == numCols - 1) -0.1 else Random.nextDouble
}

MLLabeledPoint(label, Vectors.dense(values))
(label, Vectors.dense(values))
}

sc.parallelize(labeledPoints)
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
}

val trainingRDD = buildDenseRDD().repartition(4)
val testRDD = buildDenseRDD().repartition(4).map(_.features.asInstanceOf[DenseVector])
val denseDF = buildDenseDataFrame().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
useExternalMemory = true)
xgBoostModel.predict(testRDD, missingValue = -0.1f).collect()
}

test("test consistency of prediction functions with RDD") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSet = Classification.test
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val testCollection = testRDD.collect()
for (i <- testSet.indices) {
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
}
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1 = predRDD.collect()
assert(testRDD.count() === predResult1.length)
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
for (i <- predResult1.indices; j <- predResult1(i).indices) {
assert(predResult1(i)(j) === predResult2(i)(j))
}
}

test("test eval functions with RDD") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
// Nan Zhu: deprecate it for now
// xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
}

test("test prediction functionality with empty partition") {
import DataUtils._
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers)
}
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testRDD = buildEmptyRDD()
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
println(xgBoostModel.predict(testRDD).collect().length === 0)
}

test("test use groupData") {
import DataUtils._
val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML)
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0)
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)

val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)

val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 2, nWorkers = 1)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)

val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
assert(avgMetric contains "ndcg")
// If the labels were lost ndcg comes back as 1.0
assert(avgMetric.split('=')(1).toFloat < 1F)
}

test("test use nested groupData") {
import DataUtils._
val trainingRDD0 = sc.parallelize(Ranking.train0, numSlices = 1)
val trainingRDD1 = sc.parallelize(Ranking.train1, numSlices = 1)
val trainingRDD = trainingRDD0.union(trainingRDD1).map(_.asML)

val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)

val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)

val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)

val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
val model = new XGBoostClassifier(paramMap).fit(denseDF)
model.transform(denseDF).collect()
}

test("training with spark parallelism checks disabled") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}

test("isClassificationTask correctly classifies supported objectives") {
import org.scalatest.prop.TableDrivenPropertyChecks._

val objectives = Table(
("isClassificationTask", "params"),
(true, Map("obj_type" -> "classification")),
(false, Map("obj_type" -> "regression")),
(false, Map("objective" -> "rank:ndcg")),
(false, Map("objective" -> "rank:pairwise")),
(false, Map("objective" -> "rank:map")),
(false, Map("objective" -> "count:poisson")),
(true, Map("objective" -> "binary:logistic")),
(true, Map("objective" -> "binary:logitraw")),
(true, Map("objective" -> "multi:softmax")),
(true, Map("objective" -> "multi:softprob")),
(false, Map("objective" -> "reg:linear")),
(false, Map("objective" -> "reg:logistic")),
(false, Map("objective" -> "reg:gamma")),
(false, Map("objective" -> "reg:tweedie")))
forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) =>
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L,
"num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}

test("training with checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)

val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)

val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM)

// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")

// Train next model based on prev model
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
nWorkers = numWorkers)
assert(error(tmpModel) > error(prevModel))
assert(error(prevModel) > error(nextModel))
assert(error(nextModel) < 0.1)
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
}
}

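The checkpoint test above drives the new "checkpoint_path" and "checkpoint_interval" parameters. As a minimal sketch only of how a long-running job might enable them with the estimator API (the HDFS path, round counts, and the trainingData value are illustrative assumptions, not values from this commit):

// Sketch: periodic checkpointing so a failed job can be restarted near where it stopped.
val checkpointParams = Map(
  "eta" -> "0.1", "max_depth" -> "2", "objective" -> "binary:logistic",
  "num_round" -> 100, "num_workers" -> 4,
  "checkpoint_path" -> "hdfs:///tmp/xgb-checkpoints",  // illustrative path
  "checkpoint_interval" -> 10)                         // save a booster every 10 rounds
// Re-running the same fit() after a failure is intended to resume from the
// newest booster saved under checkpoint_path, as the test above exercises.
val resumedModel = new XGBoostClassifier(checkpointParams).fit(trainingData)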
@ -1,133 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import java.nio.file.Files

import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite

class XGBoostModelSuite extends FunSuite with PerTest {
test("test model consistency after save and load") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
assert(evalResults < 0.1)
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
assert(loadedEvalResults == evalResults)
}

test("test save and load of different types of models") {
import DataUtils._
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
// validate regression model
var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.setFeaturesCol("feature_col")
xgBoostModel.setLabelCol("label_col")
xgBoostModel.setPredictionCol("prediction_col")
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
assert(loadedXGBoostModel.getLabelCol == "label_col")
assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
// classification model
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5).deep)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
// (multiclass) classification model
trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
}

test("copy and predict ClassificationModel") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testRDD = sc.parallelize(Classification.test).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
testCopy(model, testRDD)
}

test("copy and predict RegressionModel") {
import DataUtils._
val trainingRDD = sc.parallelize(Regression.train).map(_.asML)
val testRDD = sc.parallelize(Regression.test).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "reg:linear")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
testCopy(model, testRDD)
}

private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = {
val modelCopy = model.copy(ParamMap.empty)
modelCopy.summary // Ensure no exception.

val expected = model.predict(testRDD).collect
assert(modelCopy.predict(testRDD).collect === expected)
}
}
@ -0,0 +1,114 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite

class XGBoostRegressorSuite extends FunSuite with PerTest {

test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
val trainingDM = new DMatrix(Regression.train.iterator)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5

val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:linear")

val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)

val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)

val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap

assert(prediction1.indices.count { i =>
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
} < prediction1.length * 0.1)
}

test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5

val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:linear",
"num_round" -> round,
"num_workers" -> numWorkers)

// Set params in XGBoost way
val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostRegressor()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("reg:linear")
.setNumRound(round)
.setNumWorkers(numWorkers)
.fit(trainingDF)

val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()

prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) <= 0.01f)
}
}

test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group")

val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)

val prediction = model.transform(testDF).collect()
assert(testDF.count() === prediction.length)
}

test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)

val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val testDF = buildDataFrame(Regression.test)

val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
val prediction = model.transform(testDF).collect()
val first = prediction.head.getAs[Double]("prediction")
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
}
}
@ -1,138 +0,0 @@
/*
Copyright (c) 2014 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import java.io.{File, FileNotFoundException}

import scala.util.Random

import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class XGBoostSparkPipelinePersistence extends FunSuite with PerTest
with BeforeAndAfterAll {

override def afterAll(): Unit = {
delete(new File("./testxgbPipe"))
delete(new File("./testxgbEst"))
delete(new File("./testxgbModel"))
delete(new File("./test2xgbModel"))
}

private def delete(f: File) {
if (f.exists()) {
if (f.isDirectory()) {
for (c <- f.listFiles()) {
delete(c)
}
}
if (!f.delete()) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
}

test("test persistence of XGBoostEstimator") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val xgbEstimator = new XGBoostEstimator(paramMap)
xgbEstimator.write.overwrite().save("./testxgbEst")
val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}

test("test persistence of a complete pipeline") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val r = new Random(0)
val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(paramMap)
val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
pipeline.write.overwrite().save("testxgbPipe")
val loadedPipeline = Pipeline.read.load("testxgbPipe")
val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}

test("test persistence of XGBoostModel") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("label")))
.setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(Map("num_round" -> 10,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
)).setFeaturesCol("features").setLabelCol("label")
// separate
val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
predModel.write.overwrite.save("test2xgbModel")
val same2Model = XGBoostModel.load("test2xgbModel")

assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
val predParamMap = predModel.extractParamMap()
val same2ParamMap = same2Model.extractParamMap()
assert(predParamMap.get(predModel.useExternalMemory)
=== same2ParamMap.get(same2Model.useExternalMemory))
assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
assert(predParamMap.get(predModel.predictionCol)
=== same2ParamMap.get(same2Model.predictionCol))
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))

// chained
val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
predictionModel.write.overwrite.save("testxgbModel")
val sameModel = PipelineModel.load("testxgbModel")

val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head

assert(java.util.Arrays.equals(
predictionModelXGB.booster.toByteArray,
sameModelXGB.booster.toByteArray
))
val predictionModelXGBParamMap = predictionModel.extractParamMap()
val sameModelXGBParamMap = sameModel.extractParamMap()
assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
=== sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
=== sameModelXGBParamMap.get(sameModelXGB.featuresCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
=== sameModelXGBParamMap.get(sameModelXGB.predictionCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
}
}
@ -197,6 +197,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
booster.getFeatureScore(featureMap).asScala
}

def getVersion: Int = booster.getVersion

def toByteArray: Array[Byte] = {
booster.toByteArray
}