From 2c4359e914ebdd76d5c22619cdad7213d29fa206 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 18 Jun 2018 15:39:18 -0700 Subject: [PATCH] [jvm-packages] XGBoost Spark integration refactor (#3387) * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * [jvm-packages] XGBoost Spark integration refactor. (#3313) * XGBoost Spark integration refactor. * Make corresponding update for xgboost4j-example * Address comments. * [jvm-packages] Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib (#3326) * Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib * Fix extra space. * [jvm-packages] XGBoost Spark supports ranking with group data. (#3369) * XGBoost Spark supports ranking with group data. * Use Iterator.duplicate to prevent OOM. * Update CheckpointManagerSuite.scala * Resolve conflicts --- .../example/spark/SparkModelTuningTool.scala | 6 +- .../example/spark/SparkWithDataFrame.scala | 9 +- .../scala/example/spark/SparkWithRDD.scala | 58 --- .../scala/spark/CheckpointManager.scala | 15 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 241 ++-------- .../spark/XGBoostClassificationModel.scala | 181 -------- .../scala/spark/XGBoostClassifier.scala | 432 ++++++++++++++++++ .../scala/spark/XGBoostEstimator.scala | 186 -------- .../xgboost4j/scala/spark/XGBoostModel.scala | 387 ---------------- .../scala/spark/XGBoostRegressionModel.scala | 61 --- .../scala/spark/XGBoostRegressor.scala | 356 +++++++++++++++ .../scala/spark/params/BoosterParams.scala | 142 +++--- .../scala/spark/params/GeneralParams.scala | 142 +++++- .../spark/params/LearningTaskParams.scala | 54 +-- .../src/test/resources/rank-demo-0.txt.train | 75 --- .../resources/rank-demo-0.txt.train.group | 10 - .../src/test/resources/rank-demo-1.txt.train | 74 --- .../resources/rank-demo-1.txt.train.group | 10 - .../test/resources/rank-demo.txt.test.group | 10 - .../src/test/resources/rank.test.csv | 66 +++ .../{rank-demo.txt.test => rank.test.txt} | 0 .../src/test/resources/rank.train.csv | 149 ++++++ .../scala/spark/CheckpointManagerSuite.scala | 29 +- .../dmlc/xgboost4j/scala/spark/PerTest.scala | 30 +- .../scala/spark/PersistenceSuite.scala | 167 +++++++ .../xgboost4j/scala/spark/TrainTestData.scala | 20 +- .../scala/spark/XGBoostClassifierSuite.scala | 207 +++++++++ .../scala/spark/XGBoostConfigureSuite.scala | 24 +- .../scala/spark/XGBoostDFSuite.scala | 265 ----------- .../scala/spark/XGBoostGeneralSuite.scala | 301 ++++-------- .../scala/spark/XGBoostModelSuite.scala | 133 ------ .../scala/spark/XGBoostRegressorSuite.scala | 114 +++++ .../XGBoostSparkPipelinePersistence.scala | 138 ------ .../ml/dmlc/xgboost4j/scala/Booster.scala | 2 + 34 files changed, 1921 insertions(+), 2173 deletions(-) delete mode 100644 jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala delete mode 100644 
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train delete mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group delete mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train delete mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group delete mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group create mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank.test.csv rename jvm-packages/xgboost4j-spark/src/test/resources/{rank-demo.txt.test => rank.test.txt} (100%) create mode 100644 jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModelSuite.scala create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala delete mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkModelTuningTool.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkModelTuningTool.scala index c200a4ee4..0c4a7ce14 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkModelTuningTool.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkModelTuningTool.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.io.Source -import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoost} +import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor import org.apache.spark.ml.Pipeline import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer} @@ -160,10 +160,10 @@ object SparkModelTuningTool { private def crossValidation( xgboostParam: Map[String, Any], trainingData: Dataset[_]): TrainValidationSplitModel = { - val xgbEstimator = new XGBoostEstimator(xgboostParam).setFeaturesCol("features"). + val xgbEstimator = new XGBoostRegressor(xgboostParam).setFeaturesCol("features"). 
setLabelCol("logSales") val paramGrid = new ParamGridBuilder() - .addGrid(xgbEstimator.round, Array(20, 50)) + .addGrid(xgbEstimator.numRound, Array(20, 50)) .addGrid(xgbEstimator.eta, Array(0.1, 0.4)) .build() val tv = new TrainValidationSplit() diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala index c2efcc6fe..d02ba2fbd 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithDataFrame.scala @@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.example.spark import ml.dmlc.xgboost4j.scala.Booster -import ml.dmlc.xgboost4j.scala.spark.XGBoost +import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier import org.apache.spark.sql.SparkSession import org.apache.spark.SparkConf @@ -45,9 +45,10 @@ object SparkWithDataFrame { val paramMap = List( "eta" -> 0.1f, "max_depth" -> 2, - "objective" -> "binary:logistic").toMap - val xgboostModel = XGBoost.trainWithDataFrame( - trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true) + "objective" -> "binary:logistic", + "num_round" -> numRound, + "nWorkers" -> args(1).toInt).toMap + val xgboostModel = new XGBoostClassifier(paramMap).fit(trainDF) // xgboost-spark appends the column containing prediction results xgboostModel.transform(testDF).show() } diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala deleted file mode 100644 index ed3d54fa3..000000000 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.example.spark - -import ml.dmlc.xgboost4j.scala.Booster -import ml.dmlc.xgboost4j.scala.spark.XGBoost - -import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} -import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - -object SparkWithRDD { - def main(args: Array[String]): Unit = { - if (args.length != 5) { - println( - "usage: program num_of_rounds num_workers training_path test_path model_path") - sys.exit(1) - } - val sparkConf = new SparkConf().setAppName("XGBoost-spark-example") - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sparkConf.registerKryoClasses(Array(classOf[Booster])) - implicit val sc = new SparkContext(sparkConf) - val inputTrainPath = args(2) - val inputTestPath = args(3) - val outputModelPath = args(4) - // number of iterations - val numRound = args(0).toInt - val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).map(lp => - MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray))) - val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath) - .map(lp => new MLDenseVector(lp.features.toArray)) - // training parameters - val paramMap = List( - "eta" -> 0.1f, - "max_depth" -> 2, - "objective" -> "binary:logistic").toMap - val xgboostModel = XGBoost.trainWithRDD(trainRDD, paramMap, numRound, nWorkers = args(1).toInt, - useExternalMemory = true) - xgboostModel.predict(testSet, missingValue = Float.NaN) - // save model to HDFS path - xgboostModel.saveModelAsHadoopFile(outputModelPath) - } -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala index 3756c152c..bcdeee5c4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala @@ -17,6 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.Booster +import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost} import org.apache.commons.logging.LogFactory import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext @@ -63,9 +64,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) val version = versions.max val fullPath = getPath(version) logger.info(s"Start training from previous booster at $fullPath") - val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc) - model.booster.booster.setVersion(version) - model.booster + val booster = SXGBoost.loadModel(fullPath) + booster.booster.setVersion(version) + booster } else { null } @@ -76,12 +77,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) * * @param checkpoint the checkpoint to save as an XGBoostModel */ - private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = { + private[spark] def updateCheckpoint(checkpoint: Booster): Unit = { val fs = FileSystem.get(sc.hadoopConfiguration) val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version))) - val fullPath = getPath(checkpoint.version) - logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath") - checkpoint.saveModelAsHadoopFile(fullPath)(sc) + val fullPath = getPath(checkpoint.getVersion) + logger.info(s"Saving checkpoint model with version 
${checkpoint.getVersion} to $fullPath") + checkpoint.saveModel(fullPath) prevModelPaths.foreach(path => fs.delete(path, true)) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 8e4426552..f8ab10cb3 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -21,16 +21,15 @@ import java.nio.file.Files import scala.collection.mutable import scala.util.Random + import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} + import org.apache.commons.io.FileUtils import org.apache.commons.logging.LogFactory -import org.apache.hadoop.fs.{FSDataInputStream, Path} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dataset -import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext} @@ -134,7 +133,7 @@ object XGBoost extends Serializable { fromBaseMarginsToArray(baseMargins), cacheDirName) try { - val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds") + val numEarlyStoppingRounds = params.get("num_early_stopping_rounds") .map(_.toString.toInt).getOrElse(0) val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) val booster = SXGBoost.train(watches.train, params, round, @@ -148,89 +147,6 @@ object XGBoost extends Serializable { }.cache() } - /** - * Train XGBoost model with the DataFrame-represented data - * - * @param trainingData the training set represented as DataFrame - * @param params Map containing the parameters to configure XGBoost - * @param round the number of iterations - * @param nWorkers the number of xgboost workers, 0 by default which means that the number of - * workers equals to the partition number of trainingData RDD - * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default - * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default - * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as - * true, the user may save the RAM cost for running XGBoost within Spark - * @param missing The value which represents a missing value in the dataset - * @param featureCol the name of input column, "features" as default value - * @param labelCol the name of output column, "label" as default value - * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed - * @return XGBoostModel when successful training - */ - @throws(classOf[XGBoostError]) - def trainWithDataFrame( - trainingData: Dataset[_], - params: Map[String, Any], - round: Int, - nWorkers: Int, - obj: ObjectiveTrait = null, - eval: EvalTrait = null, - useExternalMemory: Boolean = false, - missing: Float = Float.NaN, - featureCol: String = "features", - labelCol: String = "label"): XGBoostModel = { - require(nWorkers > 0, "you must specify more than 0 workers") - val estimator = new XGBoostEstimator(params) - // assigning general parameters - estimator. - set(estimator.useExternalMemory, useExternalMemory). - set(estimator.round, round). - set(estimator.nWorkers, nWorkers). 
- set(estimator.customObj, obj). - set(estimator.customEval, eval). - set(estimator.missing, missing). - setFeaturesCol(featureCol). - setLabelCol(labelCol). - fit(trainingData) - } - - private[spark] def isClassificationTask(params: Map[String, Any]): Boolean = { - val objective = params.getOrElse("objective", params.getOrElse("obj_type", null)) - objective != null && { - val objStr = objective.toString - objStr != "regression" && !objStr.startsWith("reg:") && objStr != "count:poisson" && - !objStr.startsWith("rank:") - } - } - - /** - * Train XGBoost model with the RDD-represented data - * - * @param trainingData the training set represented as RDD - * @param params Map containing the configuration entries - * @param round the number of iterations - * @param nWorkers the number of xgboost workers, 0 by default which means that the number of - * workers equals to the partition number of trainingData RDD - * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default - * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default - * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as - * true, the user may save the RAM cost for running XGBoost within Spark - * @param missing the value represented the missing value in the dataset - * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed - * @return XGBoostModel when successful training - */ - @deprecated("Use XGBoost.trainWithRDD instead.") - def train( - trainingData: RDD[MLLabeledPoint], - params: Map[String, Any], - round: Int, - nWorkers: Int, - obj: ObjectiveTrait = null, - eval: EvalTrait = null, - useExternalMemory: Boolean = false, - missing: Float = Float.NaN): XGBoostModel = { - trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing) - } - private def overrideParamsAccordingToTaskCPUs( params: Map[String, Any], sc: SparkContext): Map[String, Any] = { @@ -259,39 +175,8 @@ object XGBoost extends Serializable { } /** - * Train XGBoost model with the RDD-represented data - * - * @param trainingData the training set represented as RDD - * @param params Map containing the configuration entries - * @param round the number of iterations - * @param nWorkers the number of xgboost workers, 0 by default which means that the number of - * workers equals to the partition number of trainingData RDD - * @param obj An instance of [[ObjectiveTrait]] specifying a custom objective, null by default - * @param eval An instance of [[EvalTrait]] specifying a custom evaluation metric, null by default - * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as - * true, the user may save the RAM cost for running XGBoost within Spark - * @param missing The value which represents a missing value in the dataset - * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training has failed - * @return XGBoostModel when successful training + * @return A tuple of the booster and the metrics used to build training summary */ - @throws(classOf[XGBoostError]) - def trainWithRDD( - trainingData: RDD[MLLabeledPoint], - params: Map[String, Any], - round: Int, - nWorkers: Int, - obj: ObjectiveTrait = null, - eval: EvalTrait = null, - useExternalMemory: Boolean = false, - missing: Float = Float.NaN): XGBoostModel = { - import DataUtils._ - val xgbTrainingData = trainingData.map { case MLLabeledPoint(label, features) => - features.asXGB.copy(label = 
label.toFloat) - } - trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval, - useExternalMemory, missing) - } - @throws(classOf[XGBoostError]) private[spark] def trainDistributed( trainingData: RDD[XGBLabeledPoint], @@ -301,7 +186,7 @@ object XGBoost extends Serializable { obj: ObjectiveTrait = null, eval: EvalTrait = null, useExternalMemory: Boolean = false, - missing: Float = Float.NaN): XGBoostModel = { + missing: Float = Float.NaN): (Booster, Map[String, Array[Float]]) = { if (params.contains("tree_method")) { require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" + " for now") @@ -350,20 +235,15 @@ object XGBoost extends Serializable { } sparkJobThread.setUncaughtExceptionHandler(tracker) sparkJobThread.start() - val isClsTask = isClassificationTask(params) val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L)) logger.info(s"Rabit returns with exit code $trackerReturnVal") - val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics, - sparkJobThread, isClsTask) - if (isClsTask){ - model.asInstanceOf[XGBoostClassificationModel].numOfClasses = - params.getOrElse("num_class", "2").toString.toInt - } + val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics, + sparkJobThread) if (checkpointRound < round) { - prevBooster = model.booster - checkpointManager.updateCheckpoint(model) + prevBooster = booster + checkpointManager.updateCheckpoint(prevBooster) } - model + (booster, metrics) } finally { tracker.stop() } @@ -383,17 +263,14 @@ object XGBoost extends Serializable { private def postTrackerReturnProcessing( trackerReturnVal: Int, distributedBoostersAndMetrics: RDD[(Booster, Map[String, Array[Float]])], - sparkJobThread: Thread, - isClassificationTask: Boolean - ): XGBoostModel = { + sparkJobThread: Thread): (Booster, Map[String, Array[Float]]) = { if (trackerReturnVal == 0) { // Copies of the final booster and the corresponding metrics // reside in each partition of the `distributedBoostersAndMetrics`. // Any of them can be used to create the model. 
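// Editor's note (illustrative, not part of the original patch): Rabit AllReduce keeps
// the workers synchronized, so every partition holds an identical (Booster, metrics)
// pair and `first()` below reads the same model as any other partition would. A
// hypothetical sanity check on a small run:
//   val reference = distributedBoostersAndMetrics.first()._1.toByteArray
//   distributedBoostersAndMetrics.collect().foreach { case (b, _) =>
//     assert(b.toByteArray.sameElements(reference))
//   }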
val (booster, metrics) = distributedBoostersAndMetrics.first() - val xgboostModel = XGBoostModel(booster, isClassificationTask) distributedBoostersAndMetrics.unpersist(false) - xgboostModel.setSummary(XGBoostTrainingSummary(metrics)) + (booster, metrics) } else { try { if (sparkJobThread.isAlive) { @@ -407,64 +284,6 @@ object XGBoost extends Serializable { } } - private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = { - val featureCol = inputStream.readUTF() - val labelCol = inputStream.readUTF() - val predictionCol = inputStream.readUTF() - (featureCol, labelCol, predictionCol) - } - - private def setGeneralModelParams( - featureCol: String, - labelCol: String, - predCol: String, - xgBoostModel: XGBoostModel): XGBoostModel = { - xgBoostModel.setFeaturesCol(featureCol) - xgBoostModel.setLabelCol(labelCol) - xgBoostModel.setPredictionCol(predCol) - } - - - /** - * Load XGBoost model from path in HDFS-compatible file system - * - * @param modelPath The path of the file representing the model - * @return The loaded model - */ - def loadModelFromHadoopFile(modelPath: String)(implicit sparkContext: SparkContext): - XGBoostModel = { - val path = new Path(modelPath) - val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path) - val modelType = dataInStream.readUTF() - val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream) - modelType match { - case "_cls_" => - val rawPredictionCol = dataInStream.readUTF() - val numClasses = dataInStream.readInt() - val thresholdLength = dataInStream.readInt() - var thresholds: Array[Double] = null - if (thresholdLength != -1) { - thresholds = new Array[Double](thresholdLength) - for (i <- 0 until thresholdLength) { - thresholds(i) = dataInStream.readDouble() - } - } - val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream)) - setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel). - asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol) - if (thresholdLength != -1) { - xgBoostModel.setThresholds(thresholds) - } - xgBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClasses - xgBoostModel - case "_reg_" => - val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream)) - setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel) - case other => - throw new XGBoostError(s"Unknown model type $other. 
Supported types " + - s"are: ['_reg_', '_cls_'].") - } - } } private class Watches private( @@ -489,12 +308,29 @@ private class Watches private( private object Watches { + def buildGroups(groups: Seq[Int]): Seq[Int] = { + val output = mutable.ArrayBuffer.empty[Int] + var count = 1 + var i = 1 + while (i < groups.length) { + if (groups(i) != groups(i - 1)) { + output += count + count = 1 + } else { + count += 1 + } + i += 1 + } + output += count + output + } + def apply( params: Map[String, Any], labeledPoints: Iterator[XGBLabeledPoint], baseMarginsOpt: Option[Array[Float]], cacheDirName: Option[String]): Watches = { - val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0) + val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0) val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime()) val r = new Random(seed) val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint] @@ -506,8 +342,18 @@ private object Watches { accepted } - val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull) + + val (trainIter1, trainIter2) = trainPoints.duplicate + val trainMatrix = new DMatrix(trainIter1, cacheDirName.map(_ + "/train").orNull) + val trainGroups = buildGroups(trainIter2.map(_.group).toSeq).toArray + trainMatrix.setGroup(trainGroups) + val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull) + if (trainTestRatio < 1.0) { + val testGroups = buildGroups(testPoints.map(_.group)).toArray + testMatrix.setGroup(testGroups) + } + r.setSeed(seed) for (baseMargins <- baseMarginsOpt) { val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio) @@ -515,11 +361,6 @@ private object Watches { testMatrix.setBaseMargin(testMargin) } - // TODO: use group attribute from the points. - if (params.contains("groupData") && params("groupData") != null) { - trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]]( - TaskContext.getPartitionId()).toArray) - } new Watches(trainMatrix, testMatrix, cacheDirName) } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala deleted file mode 100644 index 07a67db05..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.spark - -import scala.collection.mutable -import ml.dmlc.xgboost4j.scala.Booster -import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector} -import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, Param, ParamMap} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Dataset} - -/** - * class of the XGBoost model used for classification task - */ -class XGBoostClassificationModel private[spark]( - override val uid: String, booster: Booster) - extends XGBoostModel(booster) { - - def this(booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), booster) - - // only called in copy() - def this(uid: String) = this(uid, null) - - // scalastyle:off - - /** - * whether to output raw margin - */ - final val outputMargin = new BooleanParam(this, "outputMargin", "whether to output untransformed margin value") - - setDefault(outputMargin, false) - - def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel] - - /** - * the name of the column storing the raw prediction value, either probabilities (as default) or - * raw margin value - */ - final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).") - - setDefault(rawPredictionCol, "probabilities") - - final def getRawPredictionCol: String = $(rawPredictionCol) - - def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel] - - /** - * Thresholds in multi-class classification - */ - final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0)) - - def getThresholds: Array[Double] = $(thresholds) - - def setThresholds(value: Array[Double]): XGBoostClassificationModel = - set(thresholds, value).asInstanceOf[XGBoostClassificationModel] - - // scalastyle:on - - // generate dataframe containing raw prediction column which is typed as Vector - private def predictRaw( - testSet: Dataset[_], - temporalColName: Option[String] = None, - forceTransformedScore: Option[Boolean] = None): DataFrame = { - val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin))) - val colName = temporalColName.getOrElse($(rawPredictionCol)) - val tempColName = colName + "_arraytype" - val dsWithArrayTypedRawPredCol = testSet.sparkSession.createDataFrame(predictRDD, schema = { - testSet.schema.add(tempColName, ArrayType(FloatType, containsNull = false)) - }) - val transformerForProbabilitiesArray = - (rawPredArray: mutable.WrappedArray[Float]) => - if (numClasses == 2) { - Array(1 - rawPredArray(0), rawPredArray(0)).map(_.toDouble) - } else { - rawPredArray.map(_.toDouble).array - } - dsWithArrayTypedRawPredCol.withColumn(colName, - udf((rawPredArray: mutable.WrappedArray[Float]) => - new MLDenseVector(transformerForProbabilitiesArray(rawPredArray))).apply(col(tempColName))). - drop(tempColName) - } - - private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = { - val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol")) - val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol")) - val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF) - val allColumnNames = testSet.columns ++ Seq($(predictionCol)) - tempDF.select(allColumnNames(0), allColumnNames.tail: _*) - } - - private def argMax(vector: Array[Double]): Double = { - vector.zipWithIndex.maxBy(_._1)._2 - } - - private def raw2prediction(rawPrediction: MLDenseVector): Double = { - if (!isDefined(thresholds)) { - argMax(rawPrediction.values) - } else { - probability2prediction(rawPrediction) - } - } - - private def probability2prediction(probability: MLDenseVector): Double = { - if (!isDefined(thresholds)) { - argMax(probability.values) - } else { - val thresholds: Array[Double] = getThresholds - val scaledProbability = - probability.values.zip(thresholds).map { case (p, t) => - if (t == 0.0) Double.PositiveInfinity else p / t - } - argMax(scaledProbability) - } - } - - override protected def transformImpl(testSet: Dataset[_]): DataFrame = { - transformSchema(testSet.schema, logging = true) - if (isDefined(thresholds)) { - require($(thresholds).length == numClasses, this.getClass.getSimpleName + - ".transform() called with non-matching numClasses and thresholds.length." + - s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") - } - if ($(outputMargin)) { - setRawPredictionCol("margin") - } - var outputData = testSet - var numColsOutput = 0 - if ($(rawPredictionCol).nonEmpty) { - outputData = predictRaw(testSet) - numColsOutput += 1 - } - - if ($(predictionCol).nonEmpty) { - if ($(rawPredictionCol).nonEmpty) { - require(!$(outputMargin), "XGBoost does not support output final prediction with" + - " untransformed margin. 
Please set predictionCol as \"\" when setting outputMargin as" + - " true") - val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol))) - outputData = outputData.withColumn($(predictionCol), rawToPredUDF) - } else { - outputData = fromFeatureToPrediction(testSet) - } - numColsOutput += 1 - } - - if (numColsOutput == 0) { - this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") - } - outputData.toDF() - } - - private[spark] var numOfClasses = 2 - - def numClasses: Int = numOfClasses - - override def copy(extra: ParamMap): XGBoostClassificationModel = { - val newModel = copyValues(new XGBoostClassificationModel(booster), extra) - newModel.setSummary(summary) - } - - override protected def predict(features: MLVector): Double = { - throw new Exception("XGBoost does not support online prediction ") - } -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala new file mode 100644 index 000000000..243f2decb --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -0,0 +1,432 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala.spark + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import ml.dmlc.xgboost4j.java.Rabit +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} +import ml.dmlc.xgboost4j.scala.spark.params._ +import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} + +import org.apache.hadoop.fs.Path +import org.apache.spark.TaskContext +import org.apache.spark.ml.classification._ +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.HasWeightCol +import org.apache.spark.ml.util._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql._ +import org.json4s.DefaultFormats + +private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams + with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs + +class XGBoostClassifier ( + override val uid: String, + private val xgboostParams: Map[String, Any]) + extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel] + with XGBoostClassifierParams with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]()) + + def this(uid: String) = this(uid, Map[String, Any]()) + + def this(xgboostParams: Map[String, Any]) = this( + Identifiable.randomUID("xgbc"), xgboostParams) + + XGBoostToMLlibParams(xgboostParams) + + def setWeightCol(value: String): this.type = set(weightCol, value) + + def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value) + + def setNumClass(value: Int): this.type = set(numClass, value) + + // setters for general params + def setNumRound(value: Int): this.type = set(numRound, value) + + def setNumWorkers(value: Int): this.type = set(numWorkers, value) + + def setNthread(value: Int): this.type = set(nthread, value) + + def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value) + + def setSilent(value: Int): this.type = set(silent, value) + + def setMissing(value: Float): this.type = set(missing, value) + + def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value) + + def setCheckpointPath(value: String): this.type = set(checkpointPath, value) + + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + def setSeed(value: Long): this.type = set(seed, value) + + // setters for booster params + def setBooster(value: String): this.type = set(booster, value) + + def setEta(value: Double): this.type = set(eta, value) + + def setGamma(value: Double): this.type = set(gamma, value) + + def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + def setMinChildWeight(value: Double): this.type = set(minChildWeight, value) + + def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value) + + def setSubsample(value: Double): this.type = set(subsample, value) + + def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value) + + def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value) + + def setLambda(value: Double): this.type = set(lambda, value) + + def setAlpha(value: Double): this.type = set(alpha, value) + + def setTreeMethod(value: String): this.type = set(treeMethod, value) + + def setGrowPolicy(value: String): this.type = set(growPolicy, value) + + def setMaxBins(value: Int): this.type = set(maxBins, value) + + def setSketchEps(value: 
Double): this.type = set(sketchEps, value) + + def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value) + + def setSampleType(value: String): this.type = set(sampleType, value) + + def setNormalizeType(value: String): this.type = set(normalizeType, value) + + def setRateDrop(value: Double): this.type = set(rateDrop, value) + + def setSkipDrop(value: Double): this.type = set(skipDrop, value) + + def setLambdaBias(value: Double): this.type = set(lambdaBias, value) + + // setters for learning params + def setObjective(value: String): this.type = set(objective, value) + + def setBaseScore(value: Double): this.type = set(baseScore, value) + + def setEvalMetric(value: String): this.type = set(evalMetric, value) + + def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value) + + def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) + + // called at the start of fit/train when 'eval_metric' is not defined + private def setupDefaultEvalMetric(): String = { + require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") + if ($(objective).startsWith("multi")) { + // multi + "merror" + } else { + // binary + "error" + } + } + + override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = { + + if (!isDefined(evalMetric) || $(evalMetric).isEmpty) { + set(evalMetric, setupDefaultEvalMetric()) + } + + val _numClasses = getNumClasses(dataset) + if (isDefined(numClass) && $(numClass) != _numClasses) { + throw new Exception("The number of classes in dataset doesn't match " + + "\'num_class\' in xgboost params.") + } + + val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) { + lit(Float.NaN) + } else { + col($(baseMarginCol)) + } + + val instances: RDD[XGBLabeledPoint] = dataset.select( + col($(featuresCol)), + col($(labelCol)).cast(FloatType), + baseMargin.cast(FloatType), + weight.cast(FloatType) + ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) => + val (indices, values) = features match { + case v: SparseVector => (v.indices, v.values.map(_.toFloat)) + case v: DenseVector => (null, v.values.map(_.toFloat)) + } + XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight) + } + transformSchema(dataset.schema, logging = true) + val derivedXGBParamMap = MLlib2XGBoostParams + // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap. 
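// Editor's note (assumption, not part of the original patch): MLlib2XGBoostParams is
// read here as the inverse of the constructor's XGBoostToMLlibParams call, rendering
// the MLlib Params back into the snake_case keys that XGBoost.scala now reads
// ("train_test_ratio", "num_early_stopping_rounds", ...). For example,
//   new XGBoostClassifier(Map("num_round" -> 10, "train_test_ratio" -> 0.9))
// binds numRound = 10 and trainTestRatio = 0.9 on the estimator, and
// derivedXGBParamMap is expected to contain the same snake_case entries again for
// trainDistributed.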
+ val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap, + $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory), + $(missing)) + val model = new XGBoostClassificationModel(uid, _numClasses, _booster) + val summary = XGBoostTrainingSummary(_metrics) + model.setSummary(summary) + model + } + + override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra) +} + +object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] { + + override def load(path: String): XGBoostClassifier = super.load(path) +} + +class XGBoostClassificationModel private[ml]( + override val uid: String, + override val numClasses: Int, + private[spark] val _booster: Booster) + extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel] + with XGBoostClassifierParams with MLWritable with Serializable { + + import XGBoostClassificationModel._ + + // only called in copy() + def this(uid: String) = this(uid, 2, null) + + private var trainingSummary: Option[XGBoostTrainingSummary] = None + + /** + * Returns the summary (e.g. train/test objective history) of the model on the + * training set. An exception is thrown if no summary is available. + */ + def summary: XGBoostTrainingSummary = trainingSummary.getOrElse { + throw new IllegalStateException("No training summary available for this XGBoostModel") + } + + private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = { + trainingSummary = Some(summary) + this + } + + // TODO: Make it public once the performance issue is resolved + private def margin(features: Vector): Array[Float] = { + import DataUtils._ + val dm = new DMatrix(scala.collection.Iterator(features.asXGB)) + _booster.predict(data = dm, outPutMargin = true)(0) + } + + private def probability(features: Vector): Array[Float] = { + import DataUtils._ + val dm = new DMatrix(scala.collection.Iterator(features.asXGB)) + _booster.predict(data = dm, outPutMargin = false)(0) + } + + override def predict(features: Vector): Double = { + throw new Exception("XGBoost-Spark does not support online prediction") + } + + // This override is never invoked in practice; it exists only to satisfy the compiler. + override def predictRaw(features: Vector): Vector = { + throw new Exception("XGBoost-Spark does not support \'predictRaw\'") + } + + // This override is never invoked in practice; it exists only to satisfy the compiler. + override def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'") + } + + // Generate raw prediction and probability prediction.
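// Editor's sketch (assuming a binary objective; not part of the original patch):
// transformInternal appends two temporary Array[Float] columns per row, which
// transform() later converts into MLlib vectors:
//   _rawPrediction -> margins from predict(dm, outPutMargin = true), e.g. Array(0.37f)
//   _probability   -> scores from predict(dm, outPutMargin = false), e.g. Array(0.59f)
// For binary objectives the booster emits one score per row; transform() expands it
// to the two-class form Array(1 - p, p) before building the dense vectors.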
+ private def transformInternal(dataset: Dataset[_]): DataFrame = { + + val schema = StructType(dataset.schema.fields ++ + Seq(StructField(name = _rawPredictionCol, dataType = + ArrayType(FloatType, containsNull = false), nullable = false)) ++ + Seq(StructField(name = _probabilityCol, dataType = + ArrayType(FloatType, containsNull = false), nullable = false))) + + val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster) + val appName = dataset.sparkSession.sparkContext.appName + + val rdd = dataset.rdd.mapPartitions { rowIterator => + if (rowIterator.hasNext) { + val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap + Rabit.init(rabitEnv.asJava) + val (rowItr1, rowItr2) = rowIterator.duplicate + val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector]( + $(featuresCol))).toList.iterator + import DataUtils._ + val cacheInfo = { + if ($(useExternalMemory)) { + s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}" + } else { + null + } + } + + val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo) + try { + val rawPredictionItr = { + bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator + } + val probabilityItr = { + bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator + } + Rabit.shutdown() + rowItr1.zip(rawPredictionItr).zip(probabilityItr).map { + case ((originals: Row, rawPrediction: Row), probability: Row) => + Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq) + } + } finally { + dm.delete() + } + } else { + Iterator[Row]() + } + } + + bBooster.unpersist(blocking = false) + + dataset.sparkSession.createDataFrame(rdd, schema) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".transform() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. 
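// Editor's sketch (hypothetical usage, not part of the original patch): the three
// optional output columns follow the usual MLlib contract, so a caller sees
//   model.transform(testDF).select("rawPrediction", "probability", "prediction")
// where rawPrediction holds the margin vector, probability sums to 1.0, and
// prediction is the (optionally thresholded) class index as a Double. Setting any of
// the corresponding column-name Params to "" skips that column and its UDF below.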
+ var outputData = transformInternal(dataset) + var numColsOutput = 0 + + val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) => + Vectors.dense(rawPrediction.map(_.toDouble).toArray) + } + + val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) => + if (numClasses == 2) { + Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble)) + } else { + Vectors.dense(probability.map(_.toDouble).toArray) + } + } + + val predictUDF = udf { (probability: mutable.WrappedArray[Float]) => + // From XGBoost probability to MLlib prediction + val probabilities = if (numClasses == 2) { + Array(1 - probability(0), probability(0)).map(_.toDouble) + } else { + probability.map(_.toDouble).toArray + } + probability2prediction(Vectors.dense(probabilities)) + } + + if ($(rawPredictionCol).nonEmpty) { + outputData = outputData + .withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol))) + numColsOutput += 1 + } + + if ($(probabilityCol).nonEmpty) { + outputData = outputData + .withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol))) + numColsOutput += 1 + } + + if ($(predictionCol).nonEmpty) { + outputData = outputData + .withColumn($(predictionCol), predictUDF(col(_probabilityCol))) + numColsOutput += 1 + } + + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + .toDF + .drop(col(_rawPredictionCol)) + .drop(col(_probabilityCol)) + } + + override def copy(extra: ParamMap): XGBoostClassificationModel = { + val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra) + newModel.setSummary(summary).setParent(parent) + } + + override def write: MLWriter = + new XGBoostClassificationModel.XGBoostClassificationModelWriter(this) +} + +object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] { + + private val _rawPredictionCol = "_rawPrediction" + private val _probabilityCol = "_probability" + + override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader + + override def load(path: String): XGBoostClassificationModel = super.load(path) + + private[XGBoostClassificationModel] + class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + implicit val format = DefaultFormats + implicit val sc = super.sparkSession.sparkContext + + DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc) + // Save model data + val dataPath = new Path(path, "data").toString + val internalPath = new Path(dataPath, "XGBoostClassificationModel") + val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath) + outputStream.writeInt(instance.numClasses) + instance._booster.saveModel(outputStream) + outputStream.close() + } + } + + private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[XGBoostClassificationModel].getName + + override def load(path: String): XGBoostClassificationModel = { + implicit val sc = super.sparkSession.sparkContext + + + val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val internalPath = new Path(dataPath, "XGBoostClassificationModel") + val dataInStream = 
internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath) + val numClasses = dataInStream.readInt() + + val booster = SXGBoost.loadModel(dataInStream) + val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster) + DefaultXGBoostParamsReader.getAndSetParams(model, metadata) + model + } + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala deleted file mode 100644 index 19b93cbb1..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ /dev/null @@ -1,186 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark - -import scala.collection.mutable - -import ml.dmlc.xgboost4j.scala.spark.params._ -import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} - -import org.apache.spark.ml.Predictor -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} -import org.apache.spark.ml.param._ -import org.apache.spark.ml.util._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.FloatType -import org.apache.spark.sql.{Dataset, Row} -import org.json4s.DefaultFormats - -/** - * XGBoost Estimator to produce a XGBoost model - */ -class XGBoostEstimator private[spark]( - override val uid: String, xgboostParams: Map[String, Any]) - extends Predictor[Vector, XGBoostEstimator, XGBoostModel] - with LearningTaskParams with GeneralParams with BoosterParams with MLWritable { - - def this(xgboostParams: Map[String, Any]) = - this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any]) - - def this(uid: String) = this(uid, Map[String, Any]()) - - // called in fromXGBParamMapToParams only when eval_metric is not defined - private def setupDefaultEvalMetric(): String = { - val objFunc = xgboostParams.getOrElse("objective", xgboostParams.getOrElse("obj_type", null)) - if (objFunc == null) { - "rmse" - } else { - // compute default metric based on specified objective - val isClassificationTask = XGBoost.isClassificationTask(xgboostParams) - if (!isClassificationTask) { - // default metric for regression or ranking - if (objFunc.toString.startsWith("rank")) { - "map" - } else { - "rmse" - } - } else { - // default metric for classification - if (objFunc.toString.startsWith("multi")) { - // multi - "merror" - } else { - // binary - "error" - } - } - } - } - - private def fromXGBParamMapToParams(): Unit = { - for ((paramName, paramValue) <- xgboostParams) { - params.find(_.name == paramName) match { - case None => - case Some(_: DoubleParam) => - set(paramName, paramValue.toString.toDouble) - case Some(_: BooleanParam) => - set(paramName, paramValue.toString.toBoolean) - case Some(_: IntParam) => - set(paramName, paramValue.toString.toInt) - case Some(_: FloatParam) => - set(paramName, paramValue.toString.toFloat) - case Some(_: Param[_]) => - 
set(paramName, paramValue) - } - } - if (xgboostParams.get("eval_metric").isEmpty) { - set("eval_metric", setupDefaultEvalMetric()) - } - } - - fromXGBParamMapToParams() - - private[spark] def fromParamsToXGBParamMap: Map[String, Any] = { - val xgbParamMap = new mutable.HashMap[String, Any]() - for (param <- params) { - xgbParamMap += param.name -> $(param) - } - val r = xgbParamMap.toMap - if (!XGBoost.isClassificationTask(r) || $(numClasses) == 2) { - r - "num_class" - } else { - r - } - } - - private def ensureColumns(trainingSet: Dataset[_]): Dataset[_] = { - var newTrainingSet = trainingSet - if (!trainingSet.columns.contains($(baseMarginCol))) { - newTrainingSet = newTrainingSet.withColumn($(baseMarginCol), lit(Float.NaN)) - } - if (!trainingSet.columns.contains($(weightCol))) { - newTrainingSet = newTrainingSet.withColumn($(weightCol), lit(1.0)) - } - newTrainingSet - } - - /** - * produce a XGBoostModel by fitting the given dataset - */ - override def train(trainingSet: Dataset[_]): XGBoostModel = { - val instances = ensureColumns(trainingSet).select( - col($(featuresCol)), - col($(labelCol)).cast(FloatType), - col($(baseMarginCol)).cast(FloatType), - col($(weightCol)).cast(FloatType) - ).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) => - val (indices, values) = features match { - case v: SparseVector => (v.indices, v.values.map(_.toFloat)) - case v: DenseVector => (null, v.values.map(_.toFloat)) - } - XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight) - } - transformSchema(trainingSet.schema, logging = true) - val derivedXGBoosterParamMap = fromParamsToXGBParamMap - val trainedModel = XGBoost.trainDistributed(instances, derivedXGBoosterParamMap, - $(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory), - $(missing)).setParent(this) - val returnedModel = copyValues(trainedModel, extractParamMap()) - if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) { - returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses) - } - returnedModel - } - - override def copy(extra: ParamMap): XGBoostEstimator = { - defaultCopy(extra).asInstanceOf[XGBoostEstimator] - } - - override def write: MLWriter = new XGBoostEstimator.XGBoostEstimatorWriter(this) -} - -object XGBoostEstimator extends MLReadable[XGBoostEstimator] { - - override def read: MLReader[XGBoostEstimator] = new XGBoostEstimatorReader - - override def load(path: String): XGBoostEstimator = super.load(path) - - private[XGBoostEstimator] class XGBoostEstimatorWriter(instance: XGBoostEstimator) - extends MLWriter { - override protected def saveImpl(path: String): Unit = { - require(instance.fromParamsToXGBParamMap("custom_eval") == null && - instance.fromParamsToXGBParamMap("custom_obj") == null, - "we do not support persist XGBoostEstimator with customized evaluator and objective" + - " function for now") - implicit val format = DefaultFormats - implicit val sc = super.sparkSession.sparkContext - DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc) - } - } - - private class XGBoostEstimatorReader extends MLReader[XGBoostEstimator] { - - override def load(path: String): XGBoostEstimator = { - val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc) - val cls = Utils.classForName(metadata.className) - val instance = - cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultXGBoostParamsReader.getAndSetParams(instance, metadata) - 
instance.asInstanceOf[XGBoostEstimator] - } - } -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala deleted file mode 100644 index 8a0d6d2e6..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark - -import scala.collection.JavaConverters._ - -import ml.dmlc.xgboost4j.java.Rabit -import ml.dmlc.xgboost4j.scala.spark.params.{BoosterParams, DefaultXGBoostParamsWriter} -import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait} - -import org.apache.hadoop.fs.{FSDataOutputStream, Path} - -import org.apache.spark.ml.PredictionModel -import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} -import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector} -import org.apache.spark.ml.param.{BooleanParam, ParamMap, Params} -import org.apache.spark.ml.util._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ -import org.apache.spark.sql.types.{ArrayType, FloatType} -import org.apache.spark.{SparkContext, TaskContext} -import org.json4s.DefaultFormats - -/** - * the base class of [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]] - */ -abstract class XGBoostModel(protected var _booster: Booster) - extends PredictionModel[MLVector, XGBoostModel] with BoosterParams with Serializable - with Params with MLWritable { - - private var trainingSummary: Option[XGBoostTrainingSummary] = None - - /** - * Returns summary (e.g. train/test objective history) of model on the - * training set. An exception is thrown if no summary is available. 
- */ - def summary: XGBoostTrainingSummary = trainingSummary.getOrElse { - throw new IllegalStateException("No training summary available for this XGBoostModel") - } - - private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = { - trainingSummary = Some(summary) - this - } - - def setLabelCol(name: String): XGBoostModel = set(labelCol, name) - - // scalastyle:off - - final val useExternalMemory = new BooleanParam(this, "use_external_memory", - "whether to use external memory for prediction") - - setDefault(useExternalMemory, false) - - def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value) - - // scalastyle:on - - /** - * Predict leaf instances with the given test set (represented as RDD) - * - * @param testSet test set represented as RDD - */ - def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = { - import DataUtils._ - val broadcastBooster = testSet.sparkContext.broadcast(_booster) - testSet.mapPartitions { testSamples => - val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) - Rabit.init(rabitEnv.asJava) - if (testSamples.nonEmpty) { - val dMatrix = new DMatrix(testSamples.map(_.asXGB)) - try { - broadcastBooster.value.predictLeaf(dMatrix).iterator - } finally { - Rabit.shutdown() - dMatrix.delete() - } - } else { - Iterator() - } - } - } - - /** - * evaluate XGBoostModel with a RDD-wrapped dataset - * - * NOTE: you have to specify value of either eval or iter; when you specify both, this method - * adopts the default eval metric of model - * - * @param evalDataset the dataset used for evaluation - * @param evalName the name of evaluation - * @param evalFunc the customized evaluation function, null by default to use the default metric - * of model - * @param iter the current iteration, -1 to be null to use customized evaluation functions - * @param groupData group data specify each group size for ranking task. Top level corresponds - * to partition id, second level is the group sizes. 
- * @return the average metric over all partitions - */ - def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null, - iter: Int = -1, useExternalCache: Boolean = false, - groupData: Seq[Seq[Int]] = null): String = { - require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter") - val broadcastBooster = evalDataset.sparkContext.broadcast(_booster) - val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory)) - val appName = evalDataset.context.appName - val allEvalMetrics = evalDataset.mapPartitions { - labeledPointsPartition => - import DataUtils._ - if (labeledPointsPartition.hasNext) { - val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) - Rabit.init(rabitEnv.asJava) - val cacheFileName = { - if (broadcastUseExternalCache.value) { - s"$appName-${TaskContext.get().stageId()}-$evalName" + - s"-deval_cache-${TaskContext.getPartitionId()}" - } else { - null - } - } - val dMatrix = new DMatrix(labeledPointsPartition.map(_.asXGB), cacheFileName) - try { - if (groupData != null) { - dMatrix.setGroup(groupData(TaskContext.getPartitionId()).toArray) - } - (evalFunc, iter) match { - case (null, _) => { - val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter) - val Array(evName, predNumeric) = predStr.split(":") - Iterator(Some(evName, predNumeric.toFloat)) - } - case _ => { - val predictions = broadcastBooster.value.predict(dMatrix) - Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix)))) - } - } - } finally { - Rabit.shutdown() - dMatrix.delete() - } - } else { - Iterator(None) - } - }.filter(_.isDefined).collect() - val evalPrefix = allEvalMetrics.map(_.get._1).head - val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length - s"$evalPrefix = $evalMetricMean" - } - - /** - * Predict result with the given test set (represented as RDD) - * - * @param testSet test set represented as RDD - * @param missingValue the specified value to represent the missing value - */ - def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = { - val broadcastBooster = testSet.sparkContext.broadcast(_booster) - testSet.mapPartitions { testSamples => - val sampleArray = testSamples.toArray - val numRows = sampleArray.length - if (numRows == 0) { - Iterator() - } else { - val numColumns = sampleArray.head.size - val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) - Rabit.init(rabitEnv.asJava) - // translate to required format - val flatSampleArray = new Array[Float](numRows * numColumns) - for (i <- flatSampleArray.indices) { - flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat - } - val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) - try { - broadcastBooster.value.predict(dMatrix).iterator - } finally { - Rabit.shutdown() - dMatrix.delete() - } - } - } - } - - /** - * Predict result with the given test set (represented as RDD) - * - * @param testSet test set represented as RDD - * @param useExternalCache whether to use external cache for the test set - * @param outputMargin whether to output raw untransformed margin value - */ - def predict( - testSet: RDD[MLVector], - useExternalCache: Boolean = false, - outputMargin: Boolean = false): RDD[Array[Float]] = { - val broadcastBooster = testSet.sparkContext.broadcast(_booster) - val appName = testSet.context.appName - testSet.mapPartitions { testSamples => - if (testSamples.nonEmpty) { - 
import DataUtils._ - val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap - Rabit.init(rabitEnv.asJava) - val cacheFileName = { - if (useExternalCache) { - s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}" - } else { - null - } - } - val dMatrix = new DMatrix(testSamples.map(_.asXGB), cacheFileName) - try { - broadcastBooster.value.predict(dMatrix).iterator - } finally { - Rabit.shutdown() - dMatrix.delete() - } - } else { - Iterator() - } - } - } - - protected def transformImpl(testSet: Dataset[_]): DataFrame - - /** - * append leaf index of each row as an additional column in the original dataset - * - * @return the original dataframe with an additional column containing prediction results - */ - def transformLeaf(testSet: Dataset[_]): DataFrame = { - val predictRDD = produceRowRDD(testSet, predLeaf = true) - setPredictionCol("predLeaf") - transformSchema(testSet.schema, logging = true) - testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol), - ArrayType(FloatType, containsNull = false))) - } - - protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false, - predLeaf: Boolean = false): RDD[Row] = { - val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster) - val appName = testSet.sparkSession.sparkContext.appName - testSet.rdd.mapPartitions { - rowIterator => - if (rowIterator.hasNext) { - val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap - Rabit.init(rabitEnv.asJava) - val (rowItr1, rowItr2) = rowIterator.duplicate - val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector]( - $(featuresCol))).toList.iterator - import DataUtils._ - val cachePrefix = { - if ($(useExternalMemory)) { - s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}" - } else { - null - } - } - val testDataset = new DMatrix(vectorIterator.map(_.asXGB), cachePrefix) - try { - val rawPredictResults = { - if (!predLeaf) { - broadcastBooster.value.predict(testDataset, outputMargin).map(Row(_)).iterator - } else { - broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator - } - } - Rabit.shutdown() - // concatenate original data partition and predictions - rowItr1.zip(rawPredictResults).map { - case (originalColumns: Row, predictColumn: Row) => - Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq) - } - } finally { - testDataset.delete() - } - } else { - Iterator[Row]() - } - } - } - - /** - * produces the prediction results and append as an additional column in the original dataset - * NOTE: the prediction results is kept as the original format of xgboost - * - * @return the original dataframe with an additional column containing prediction results - */ - override def transform(testSet: Dataset[_]): DataFrame = { - transformImpl(testSet) - } - - private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = { - outputStream.writeUTF(getFeaturesCol) - outputStream.writeUTF(getLabelCol) - outputStream.writeUTF(getPredictionCol) - } - - /** - * Save the model as to HDFS-compatible file system. - * - * @param modelPath The model path as in Hadoop path. 
- */ - def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = { - val path = new Path(modelPath) - val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path) - // output model type - this match { - case model: XGBoostClassificationModel => - outputStream.writeUTF("_cls_") - saveGeneralModelParam(outputStream) - outputStream.writeUTF(model.getRawPredictionCol) - outputStream.writeInt(model.numClasses) - // threshold - // threshold length - if (!isDefined(model.thresholds)) { - outputStream.writeInt(-1) - } else { - val thresholdLength = model.getThresholds.length - outputStream.writeInt(thresholdLength) - for (i <- 0 until thresholdLength) { - outputStream.writeDouble(model.getThresholds(i)) - } - } - case model: XGBoostRegressionModel => - outputStream.writeUTF("_reg_") - // eventual prediction col - saveGeneralModelParam(outputStream) - } - // booster - _booster.saveModel(outputStream) - outputStream.close() - } - - def booster: Booster = _booster - - def version: Int = this.booster.booster.getVersion - - override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra) - - override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this) -} - -object XGBoostModel extends MLReadable[XGBoostModel] { - private[spark] def apply(booster: Booster, isClassification: Boolean): XGBoostModel = { - if (!isClassification) { - new XGBoostRegressionModel(booster) - } else { - new XGBoostClassificationModel(booster) - } - } - - override def read: MLReader[XGBoostModel] = new XGBoostModelModelReader - - override def load(path: String): XGBoostModel = super.load(path) - - private[XGBoostModel] class XGBoostModelModelWriter(instance: XGBoostModel) extends MLWriter { - override protected def saveImpl(path: String): Unit = { - implicit val format = DefaultFormats - implicit val sc = super.sparkSession.sparkContext - DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc) - val dataPath = new Path(path, "data").toString - instance.saveModelAsHadoopFile(dataPath) - } - } - - private class XGBoostModelModelReader extends MLReader[XGBoostModel] { - - override def load(path: String): XGBoostModel = { - implicit val sc = super.sparkSession.sparkContext - val dataPath = new Path(path, "data").toString - // not used / all data resides in platform independent xgboost model file - // val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className) - XGBoost.loadModelFromHadoopFile(dataPath) - } - } -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala deleted file mode 100644 index eb4c767aa..000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.spark - -import scala.collection.mutable - -import ml.dmlc.xgboost4j.scala.Booster -import org.apache.spark.ml.linalg.{Vector => MLVector} -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{ArrayType, FloatType} - -/** - * class of XGBoost model used for regression task - */ -class XGBoostRegressionModel private[spark](override val uid: String, booster: Booster) - extends XGBoostModel(booster) { - - def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster) - - // only called in copy() - def this(uid: String) = this(uid, null) - - override protected def transformImpl(testSet: Dataset[_]): DataFrame = { - transformSchema(testSet.schema, logging = true) - val predictRDD = produceRowRDD(testSet) - val tempPredColName = $(predictionCol) + "_temp" - val transformerForArrayTypedPredCol = - udf((regressionResults: mutable.WrappedArray[Float]) => regressionResults(0)) - testSet.sparkSession.createDataFrame(predictRDD, - schema = testSet.schema.add(tempPredColName, ArrayType(FloatType, containsNull = false)) - ).withColumn( - $(predictionCol), - transformerForArrayTypedPredCol.apply(col(tempPredColName))).drop(tempPredColName) - } - - override protected def predict(features: MLVector): Double = { - throw new Exception("XGBoost does not support online prediction for now") - } - - override def copy(extra: ParamMap): XGBoostRegressionModel = { - val newModel = copyValues(new XGBoostRegressionModel(booster), extra) - newModel.setSummary(summary) - } -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala new file mode 100644 index 000000000..097a0434c --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -0,0 +1,356 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala.spark + +import scala.collection.JavaConverters._ + +import ml.dmlc.xgboost4j.java.Rabit +import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} +import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _} +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} + +import org.apache.hadoop.fs.Path +import org.apache.spark.TaskContext +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} +import org.apache.spark.ml.param.shared.HasWeightCol +import org.apache.spark.ml.util._ +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.json4s.DefaultFormats + +import scala.collection.mutable + +private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams + with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol + with ParamMapFuncs + +class XGBoostRegressor ( + override val uid: String, + private val xgboostParams: Map[String, Any]) + extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel] + with XGBoostRegressorParams with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]()) + + def this(uid: String) = this(uid, Map[String, Any]()) + + def this(xgboostParams: Map[String, Any]) = this( + Identifiable.randomUID("xgbr"), xgboostParams) + + XGBoostToMLlibParams(xgboostParams) + + def setWeightCol(value: String): this.type = set(weightCol, value) + + def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value) + + def setGroupCol(value: String): this.type = set(groupCol, value) + + // setters for general params + def setNumRound(value: Int): this.type = set(numRound, value) + + def setNumWorkers(value: Int): this.type = set(numWorkers, value) + + def setNthread(value: Int): this.type = set(nthread, value) + + def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value) + + def setSilent(value: Int): this.type = set(silent, value) + + def setMissing(value: Float): this.type = set(missing, value) + + def setTimeoutRequestWorkers(value: Long): this.type = set(timeoutRequestWorkers, value) + + def setCheckpointPath(value: String): this.type = set(checkpointPath, value) + + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + def setSeed(value: Long): this.type = set(seed, value) + + // setters for booster params + def setBooster(value: String): this.type = set(booster, value) + + def setEta(value: Double): this.type = set(eta, value) + + def setGamma(value: Double): this.type = set(gamma, value) + + def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + def setMinChildWeight(value: Double): this.type = set(minChildWeight, value) + + def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value) + + def setSubsample(value: Double): this.type = set(subsample, value) + + def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value) + + def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value) + + def setLambda(value: Double): this.type = set(lambda, value) + + def setAlpha(value: Double): this.type = set(alpha, value) + + def setTreeMethod(value: String): this.type = set(treeMethod, value) + + def setGrowPolicy(value: String): this.type = set(growPolicy, value) + + def setMaxBins(value: Int): this.type = set(maxBins, value) 
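(The fluent setters in this class mirror XGBoost's native snake_case keys in camelCase form, and the Map-based constructor is translated to the same Params by XGBoostToMLlibParams. A minimal sketch of the two equivalent ways to configure the estimator — parameter values below are illustrative, not defaults:)

    // Sketch: both forms set the same underlying Params; values are illustrative.
    val viaSetters = new XGBoostRegressor()
      .setObjective("reg:linear")
      .setEta(0.1)
      .setMaxDepth(6)
      .setNumRound(100)
      .setNumWorkers(2)

    // snake_case keys are converted to the camelCase Params above
    // (e.g. "max_depth" -> maxDepth) by XGBoostToMLlibParams.
    val viaMap = new XGBoostRegressor(Map(
      "objective" -> "reg:linear",
      "eta" -> 0.1,
      "max_depth" -> 6,
      "num_round" -> 100,
      "num_workers" -> 2))
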
+ + def setSketchEps(value: Double): this.type = set(sketchEps, value) + + def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value) + + def setSampleType(value: String): this.type = set(sampleType, value) + + def setNormalizeType(value: String): this.type = set(normalizeType, value) + + def setRateDrop(value: Double): this.type = set(rateDrop, value) + + def setSkipDrop(value: Double): this.type = set(skipDrop, value) + + def setLambdaBias(value: Double): this.type = set(lambdaBias, value) + + // setters for learning params + def setObjective(value: String): this.type = set(objective, value) + + def setBaseScore(value: Double): this.type = set(baseScore, value) + + def setEvalMetric(value: String): this.type = set(evalMetric, value) + + def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value) + + def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) + + // called at the start of fit/train when 'eval_metric' is not defined + private def setupDefaultEvalMetric(): String = { + require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") + if ($(objective).startsWith("rank")) { + "map" + } else { + "rmse" + } + } + + override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = { + + if (!isDefined(evalMetric) || $(evalMetric).isEmpty) { + set(evalMetric, setupDefaultEvalMetric()) + } + + val weight = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val baseMargin = if (!isDefined(baseMarginCol) || $(baseMarginCol).isEmpty) { + lit(Float.NaN) + } else { + col($(baseMarginCol)) + } + val group = if (!isDefined(groupCol) || $(groupCol).isEmpty) lit(-1) else col($(groupCol)) + + val instances: RDD[XGBLabeledPoint] = dataset.select( + col($(labelCol)).cast(FloatType), + col($(featuresCol)), + weight.cast(FloatType), + group.cast(IntegerType), + baseMargin.cast(FloatType) + ).rdd.map { + case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) => + val (indices, values) = features match { + case v: SparseVector => (v.indices, v.values.map(_.toFloat)) + case v: DenseVector => (null, v.values.map(_.toFloat)) + } + XGBLabeledPoint(label, indices, values, weight, group, baseMargin) + } + transformSchema(dataset.schema, logging = true) + val derivedXGBParamMap = MLlib2XGBoostParams + // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap. + val (_booster, _metrics) = XGBoost.trainDistributed(instances, derivedXGBParamMap, + $(numRound), $(numWorkers), $(customObj), $(customEval), $(useExternalMemory), + $(missing)) + val model = new XGBoostRegressionModel(uid, _booster) + val summary = XGBoostTrainingSummary(_metrics) + model.setSummary(summary) + model + } + + override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra) +} + +object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] { + + override def load(path: String): XGBoostRegressor = super.load(path) +} + +class XGBoostRegressionModel private[ml] ( + override val uid: String, + private[spark] val _booster: Booster) + extends PredictionModel[Vector, XGBoostRegressionModel] + with XGBoostRegressorParams with MLWritable with Serializable { + + import XGBoostRegressionModel._ + + // only called in copy() + def this(uid: String) = this(uid, null) + + private var trainingSummary: Option[XGBoostTrainingSummary] = None + + /** + * Returns summary (e.g. train/test objective history) of model on the + * training set. 
An exception is thrown if no summary is available. + */ + def summary: XGBoostTrainingSummary = trainingSummary.getOrElse { + throw new IllegalStateException("No training summary available for this XGBoostModel") + } + + private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = { + trainingSummary = Some(summary) + this + } + + override def predict(features: Vector): Double = { + throw new Exception("XGBoost-Spark does not support online prediction") + } + + private def transformInternal(dataset: Dataset[_]): DataFrame = { + + val schema = StructType(dataset.schema.fields ++ + Seq(StructField(name = _originalPredictionCol, dataType = + ArrayType(FloatType, containsNull = false), nullable = false))) + + val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster) + val appName = dataset.sparkSession.sparkContext.appName + + val rdd = dataset.rdd.mapPartitions { rowIterator => + if (rowIterator.hasNext) { + val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap + Rabit.init(rabitEnv.asJava) + val (rowItr1, rowItr2) = rowIterator.duplicate + val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector]( + $(featuresCol))).toList.iterator + import DataUtils._ + val cacheInfo = { + if ($(useExternalMemory)) { + s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}" + } else { + null + } + } + + val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo) + try { + val originalPredictionItr = { + bBooster.value.predict(dm).map(Row(_)).iterator + } + Rabit.shutdown() + rowItr1.zip(originalPredictionItr).map { + case (originals: Row, originalPrediction: Row) => + Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq) + } + } finally { + dm.delete() + } + } else { + Iterator[Row]() + } + } + + bBooster.unpersist(blocking = false) + + dataset.sparkSession.createDataFrame(rdd, schema) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. 
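    // Sketch of the flow, inferred from the code below: transformInternal appends the
    // booster's raw Array[Float] output as the temporary "_originalPrediction" column;
    // a UDF then lifts element 0 of that array into the Double-typed prediction column,
    // and the temporary column is dropped before the DataFrame is returned.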
+ var outputData = transformInternal(dataset) + var numColsOutput = 0 + + val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) => + originalPrediction(0).toDouble + } + + if ($(predictionCol).nonEmpty) { + outputData = outputData + .withColumn($(predictionCol), predictUDF(col(_originalPredictionCol))) + numColsOutput += 1 + } + + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData.toDF.drop(col(_originalPredictionCol)) + } + + override def copy(extra: ParamMap): XGBoostRegressionModel = { + val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra) + newModel.setSummary(summary).setParent(parent) + } + + override def write: MLWriter = + new XGBoostRegressionModel.XGBoostRegressionModelWriter(this) +} + +object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] { + + private val _originalPredictionCol = "_originalPrediction" + + override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader + + override def load(path: String): XGBoostRegressionModel = super.load(path) + + private[XGBoostRegressionModel] + class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + implicit val format = DefaultFormats + implicit val sc = super.sparkSession.sparkContext + DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc) + // Save model data + val dataPath = new Path(path, "data").toString + val internalPath = new Path(dataPath, "XGBoostRegressionModel") + val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath) + instance._booster.saveModel(outputStream) + outputStream.close() + } + } + + private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[XGBoostRegressionModel].getName + + override def load(path: String): XGBoostRegressionModel = { + implicit val sc = super.sparkSession.sparkContext + + val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val internalPath = new Path(dataPath, "XGBoostRegressionModel") + val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath) + + val booster = SXGBoost.loadModel(dataInStream) + val model = new XGBoostRegressionModel(metadata.uid, booster) + DefaultXGBoostParamsReader.getAndSetParams(model, metadata) + model + } + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala index 6e980f8aa..0bc54fd6f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala @@ -20,40 +20,48 @@ import scala.collection.immutable.HashSet import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params} -trait BoosterParams extends Params { +private[spark] trait BoosterParams extends Params { /** * Booster to use, options: {'gbtree', 'gblinear', 'dart'} */ - val boosterType = new Param[String](this, "booster", + final val booster = new Param[String](this, "booster", s"Booster to use, 
options: {'gbtree', 'gblinear', 'dart'}", (value: String) => BoosterParams.supportedBoosters.contains(value.toLowerCase)) + final def getBooster: String = $(booster) + /** * step size shrinkage used in update to prevents overfitting. After each boosting step, we * can directly get the weights of new features and eta actually shrinks the feature weights * to make the boosting process more conservative. [default=0.3] range: [0,1] */ - val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" + + final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" + " overfitting. After each boosting step, we can directly get the weights of new features." + " and eta actually shrinks the feature weights to make the boosting process more conservative.", (value: Double) => value >= 0 && value <= 1) + final def getEta: Double = $(eta) + /** * minimum loss reduction required to make a further partition on a leaf node of the tree. * the larger, the more conservative the algorithm will be. [default=0] range: [0, * Double.MaxValue] */ - val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a further" + - " partition on a leaf node of the tree. the larger, the more conservative the algorithm" + - " will be.", (value: Double) => value >= 0) + final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " + + "further partition on a leaf node of the tree. the larger, the more conservative the " + + "algorithm will be.", (value: Double) => value >= 0) + + final def getGamma: Double = $(gamma) /** * maximum depth of a tree, increase this value will make model more complex / likely to be * overfitting. [default=6] range: [1, Int.MaxValue] */ - val maxDepth = new IntParam(this, "max_depth", "maximum depth of a tree, increase this value" + - " will make model more complex/likely to be overfitting.", (value: Int) => value >= 1) + final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " + + "value will make model more complex/likely to be overfitting.", (value: Int) => value >= 1) + + final def getMaxDepth: Int = $(maxDepth) /** * minimum sum of instance weight(hessian) needed in a child. If the tree partition step results @@ -62,13 +70,15 @@ trait BoosterParams extends Params { * to minimum number of instances needed to be in each node. The larger, the more conservative * the algorithm will be. [default=1] range: [0, Double.MaxValue] */ - val minChildWeight = new DoubleParam(this, "min_child_weight", "minimum sum of instance" + + final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" + " weight(hessian) needed in a child. If the tree partition step results in a leaf node with" + " the sum of instance weight less than min_child_weight, then the building process will" + " give up further partitioning. In linear regression mode, this simply corresponds to minimum" + " number of instances needed to be in each node. The larger, the more conservative" + " the algorithm will be.", (value: Double) => value >= 0) + final def getMinChildWeight: Double = $(minChildWeight) + /** * Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it * means there is no constraint. If it is set to a positive value, it can help making the update @@ -76,90 +86,113 @@ trait BoosterParams extends Params { * regression when class is extremely imbalanced. Set it to value of 1-10 might help control the * update. 
[default=0] range: [0, Double.MaxValue] */ - val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow each" + - " tree's weight" + + final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " + + "each tree's weight" + " estimation to be. If the value is set to 0, it means there is no constraint. If it is set" + " to a positive value, it can help making the update step more conservative. Usually this" + " parameter is not needed, but it might help in logistic regression when class is extremely" + " imbalanced. Set it to value of 1-10 might help control the update", (value: Double) => value >= 0) + final def getMaxDeltaStep: Double = $(maxDeltaStep) + /** * subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly * collected half of the data instances to grow trees and this will prevent overfitting. * [default=1] range:(0,1] */ - val subSample = new DoubleParam(this, "subsample", "subsample ratio of the training instance." + - " Setting it to 0.5 means that XGBoost randomly collected half of the data instances to" + - " grow trees and this will prevent overfitting.", (value: Double) => value <= 1 && value > 0) + final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " + + "instance. Setting it to 0.5 means that XGBoost randomly collected half of the data " + + "instances to grow trees and this will prevent overfitting.", + (value: Double) => value <= 1 && value > 0) + + final def getSubsample: Double = $(subsample) /** * subsample ratio of columns when constructing each tree. [default=1] range: (0,1] */ - val colSampleByTree = new DoubleParam(this, "colsample_bytree", "subsample ratio of columns" + - " when constructing each tree.", (value: Double) => value <= 1 && value > 0) + final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " + + "columns when constructing each tree.", (value: Double) => value <= 1 && value > 0) + + final def getColsampleBytree: Double = $(colsampleBytree) /** * subsample ratio of columns for each split, in each level. [default=1] range: (0,1] */ - val colSampleByLevel = new DoubleParam(this, "colsample_bylevel", "subsample ratio of columns" + - " for each split, in each level.", (value: Double) => value <= 1 && value > 0) + final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " + + "columns for each split, in each level.", (value: Double) => value <= 1 && value > 0) + + final def getColsampleBylevel: Double = $(colsampleBylevel) /** * L2 regularization term on weights, increase this value will make model more conservative. * [default=1] */ - val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, increase this" + - " value will make model more conservative.", (value: Double) => value >= 0) + final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, " + + "increase this value will make model more conservative.", (value: Double) => value >= 0) + + final def getLambda: Double = $(lambda) /** * L1 regularization term on weights, increase this value will make model more conservative. 
* [default=0] */ - val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase this" + - " value will make model more conservative.", (value: Double) => value >= 0) + final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase " + + "this value will make model more conservative.", (value: Double) => value >= 0) + + final def getAlpha: Double = $(alpha) /** * The tree construction algorithm used in XGBoost. options: {'auto', 'exact', 'approx'} * [default='auto'] */ - val treeMethod = new Param[String](this, "tree_method", + final val treeMethod = new Param[String](this, "treeMethod", "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist'}", (value: String) => BoosterParams.supportedTreeMethods.contains(value)) + final def getTreeMethod: String = $(treeMethod) + /** * growth policy for fast histogram algorithm */ - val growthPolicty = new Param[String](this, "grow_policy", + final val growPolicy = new Param[String](this, "growPolicy", "growth policy for fast histogram algorithm", (value: String) => BoosterParams.supportedGrowthPolicies.contains(value)) + final def getGrowPolicy: String = $(growPolicy) + /** * maximum number of bins in histogram */ - val maxBins = new IntParam(this, "max_bin", "maximum number of bins in histogram", + final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram", (value: Int) => value > 0) + final def getMaxBins: Int = $(maxBins) + /** * This is only used for approximate greedy algorithm. * This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select * number of bins, this comes with theoretical guarantee with sketch accuracy. * [default=0.03] range: (0, 1) */ - val sketchEps = new DoubleParam(this, "sketch_eps", + final val sketchEps = new DoubleParam(this, "sketchEps", "This is only used for approximate greedy algorithm. This roughly translated into" + " O(1 / sketch_eps) number of bins. Compared to directly select number of bins, this comes" + " with theoretical guarantee with sketch accuracy.", (value: Double) => value < 1 && value > 0) + final def getSketchEps: Double = $(sketchEps) + /** * Control the balance of positive and negative weights, useful for unbalanced classes. A typical * value to consider: sum(negative cases) / sum(positive cases). [default=1] */ - val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of positive" + - " and negative weights, useful for unbalanced classes. A typical value to consider:" + + final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " + + "positive and negative weights, useful for unbalanced classes. A typical value to consider:" + " sum(negative cases) / sum(positive cases)") + final def getScalePosWeight: Double = $(scalePosWeight) + // Dart boosters /** @@ -167,72 +200,59 @@ trait BoosterParams extends Params { * Type of sampling algorithm. "uniform": dropped trees are selected uniformly. * "weighted": dropped trees are selected in proportion to weight. [default="uniform"] */ - val sampleType = new Param[String](this, "sample_type", "type of sampling algorithm, options:" + - " {'uniform', 'weighted'}", + final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " + + "options: {'uniform', 'weighted'}", (value: String) => BoosterParams.supportedSampleType.contains(value)) + final def getSampleType: String = $(sampleType) + /** * Parameter of Dart booster. 
* type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"] */ - val normalizeType = new Param[String](this, "normalize_type", "type of normalization" + + final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" + " algorithm, options: {'tree', 'forest'}", (value: String) => BoosterParams.supportedNormalizeType.contains(value)) + final def getNormalizeType: String = $(normalizeType) + /** * Parameter of Dart booster. * dropout rate. [default=0.0] range: [0.0, 1.0] */ - val rateDrop = new DoubleParam(this, "rate_drop", "dropout rate", (value: Double) => + final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) => value >= 0 && value <= 1) + final def getRateDrop: Double = $(rateDrop) + /** * Parameter of Dart booster. * probability of skip dropout. If a dropout is skipped, new trees are added in the same manner * as gbtree. [default=0.0] range: [0.0, 1.0] */ - val skipDrop = new DoubleParam(this, "skip_drop", "probability of skip dropout. If" + + final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skip dropout. If" + " a dropout is skipped, new trees are added in the same manner as gbtree.", (value: Double) => value >= 0 && value <= 1) + final def getSkipDrop: Double = $(skipDrop) + // linear booster /** * Parameter of linear booster * L2 regularization term on bias, default 0(no L1 reg on bias because it is not important) */ - val lambdaBias = new DoubleParam(this, "lambda_bias", "L2 regularization term on bias, default" + - " 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0) + final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " + + "default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0) - setDefault(boosterType -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6, + final def getLambdaBias: Double = $(lambdaBias) + + setDefault(booster -> "gbtree", eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0, - growthPolicty -> "depthwise", maxBins -> 16, - subSample -> 1, colSampleByTree -> 1, colSampleByLevel -> 1, + growPolicy -> "depthwise", maxBins -> 16, + subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1, lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03, scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree", rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0) - - /** - * Explains all params of this instance. See `explainParam()`. 
- */ - override def explainParams(): String = { - // TODO: filter some parameters according to the booster type - val boosterTypeStr = $(boosterType) - val validParamList = { - if (boosterTypeStr == "gblinear") { - // gblinear - params.filter(param => param.name == "lambda" || - param.name == "alpha" || param.name == "lambda_bias") - } else if (boosterTypeStr != "dart") { - // gbtree - params.filter(param => param.name != "sample_type" && - param.name != "normalize_type" && param.name != "rate_drop" && param.name != "skip_drop") - } else { - // dart - params.filter(_.name != "lambda_bias") - } - } - explainParam(boosterType) + "\n" ++ validParamList.map(explainParam).mkString("\n") - } } private[spark] object BoosterParams { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 90a943b49..4dc9e7a39 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -16,84 +16,104 @@ package ml.dmlc.xgboost4j.scala.spark.params +import com.google.common.base.CaseFormat import ml.dmlc.xgboost4j.scala.spark.TrackerConf - import org.apache.spark.ml.param._ +import scala.collection.mutable -trait GeneralParams extends Params { +private[spark] trait GeneralParams extends Params { /** * The number of rounds for boosting */ - val round = new IntParam(this, "num_round", "The number of rounds for boosting", + final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting", ParamValidators.gtEq(1)) + final def getNumRound: Int = $(numRound) + /** * number of workers used to train xgboost model. default: 1 */ - val nWorkers = new IntParam(this, "nworkers", "number of workers used to run xgboost", + final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost", ParamValidators.gtEq(1)) + final def getNumWorkers: Int = $(numWorkers) + /** * number of threads used by per worker. default 1 */ - val numThreadPerTask = new IntParam(this, "nthread", "number of threads used by per worker", + final val nthread = new IntParam(this, "nthread", "number of threads used by per worker", ParamValidators.gtEq(1)) + final def getNthread: Int = $(nthread) + /** * whether to use external memory as cache. default: false */ - val useExternalMemory = new BooleanParam(this, "use_external_memory", "whether to use external" + - "memory as cache") + final val useExternalMemory = new BooleanParam(this, "useExternalMemory", + "whether to use external memory as cache") + + final def getUseExternalMemory: Boolean = $(useExternalMemory) /** * 0 means printing running messages, 1 means silent mode. default: 0 */ - val silent = new IntParam(this, "silent", + final val silent = new IntParam(this, "silent", "0 means printing running messages, 1 means silent mode.", (value: Int) => value >= 0 && value <= 1) + final def getSilent: Int = $(silent) + /** * customized objective function provided by user. default: null */ - val customObj = new CustomObjParam(this, "custom_obj", "customized objective function " + + final val customObj = new CustomObjParam(this, "customObj", "customized objective function " + "provided by user") /** * customized evaluation function provided by user. 
default: null */ - val customEval = new CustomEvalParam(this, "custom_eval", "customized evaluation function " + - "provided by user") + final val customEval = new CustomEvalParam(this, "customEval", + "customized evaluation function provided by user") /** * the value treated as missing. default: Float.NaN */ - val missing = new FloatParam(this, "missing", "the value treated as missing") + final val missing = new FloatParam(this, "missing", "the value treated as missing") + + final def getMissing: Float = $(missing) /** * the maximum time to wait for the job requesting new workers. default: 30 minutes */ - val timeoutRequestWorkers = new LongParam(this, "timeout_request_workers", "the maximum time to" + - " request new Workers if numCores are insufficient. The timeout will be disabled if this" + - " value is set smaller than or equal to 0.") + final val timeoutRequestWorkers = new LongParam(this, "timeoutRequestWorkers", "the maximum " + + "time to request new Workers if numCores are insufficient. The timeout will be disabled " + + "if this value is set smaller than or equal to 0.") + + final def getTimeoutRequestWorkers: Long = $(timeoutRequestWorkers) /** * The hdfs folder to load and save checkpoint boosters. default: `empty_string` */ - val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " + - "save checkpoints. If there are existing checkpoints in checkpoint_path. The job will load " + - "the checkpoint with highest version as the starting point for training. If " + + final val checkpointPath = new Param[String](this, "checkpointPath", "the hdfs folder to load " + + "and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will " + + "load the checkpoint with highest version as the starting point for training. If " + "checkpoint_interval is also set, the job will save a checkpoint every a few rounds.") + final def getCheckpointPath: String = $(checkpointPath) + /** * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that * the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must * also be set if the checkpoint interval is greater than 0. */ - val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint " + - "interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained model will get " + - "checkpointed every 10 iterations. Note: `checkpoint_path` must also be set if the checkpoint" + - " interval is greater than 0.", (interval: Int) => interval == -1 || interval >= 1) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", + "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained " + + "model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be " + + "set if the checkpoint interval is greater than 0.", + (interval: Int) => interval == -1 || interval >= 1) + + final def getCheckpointInterval: Int = $(checkpointInterval) /** * Rabit tracker configurations. The parameter must be provided as an instance of the @@ -122,15 +142,87 @@ trait GeneralParams extends Params { * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf). * Ignored if the tracker implementation is "python". 
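 * For example (a sketch, assuming the case-class form
 * TrackerConf(workerConnectionTimeout, trackerImpl)):
 *   TrackerConf(60 * 1000L, "scala") requests the Scala tracker with a one-minute
 *   worker-connection timeout, while TrackerConf() keeps the defaults.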
*/ - val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations") + final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations") /** Random seed for the C++ part of XGBoost and train/test splitting. */ - val seed = new LongParam(this, "seed", "random seed") + final val seed = new LongParam(this, "seed", "random seed") - setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1, + final def getSeed: Long = $(seed) + + setDefault(numRound -> 1, numWorkers -> 1, nthread -> 1, useExternalMemory -> false, silent -> 0, customObj -> null, customEval -> null, missing -> Float.NaN, trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L, checkpointPath -> "", checkpointInterval -> -1 ) } + +trait HasBaseMarginCol extends Params { + + /** + * Param for initial prediction (aka base margin) column name. + * @group param + */ + final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol", + "Initial prediction (aka base margin) column name.") + + /** @group getParam */ + final def getBaseMarginCol: String = $(baseMarginCol) +} + +trait HasGroupCol extends Params { + + /** + * Param for group column name. + * @group param + */ + final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.") + + /** @group getParam */ + final def getGroupCol: String = $(groupCol) + +} + +trait HasNumClass extends Params { + + /** + * number of classes + */ + final val numClass = new IntParam(this, "numClass", "number of classes") + + /** @group getParam */ + final def getNumClass: Int = $(numClass) +} + +private[spark] trait ParamMapFuncs extends Params { + + def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = { + for ((paramName, paramValue) <- xgboostParams) { + val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) + params.find(_.name == name) match { + case None => + case Some(_: DoubleParam) => + set(name, paramValue.toString.toDouble) + case Some(_: BooleanParam) => + set(name, paramValue.toString.toBoolean) + case Some(_: IntParam) => + set(name, paramValue.toString.toInt) + case Some(_: FloatParam) => + set(name, paramValue.toString.toFloat) + case Some(_: Param[_]) => + set(name, paramValue) + } + } + } + + def MLlib2XGBoostParams: Map[String, Any] = { + val xgboostParams = new mutable.HashMap[String, Any]() + for (param <- params) { + if (isDefined(param)) { + val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name) + xgboostParams += name -> $(param) + } + } + xgboostParams.toMap + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 0c5055bf5..e477b242f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -20,76 +20,70 @@ import scala.collection.immutable.HashSet import org.apache.spark.ml.param._ -trait LearningTaskParams extends Params { - - /** - * number of tasks to learn - */ - val numClasses = new IntParam(this, "num_class", "number of classes") +private[spark] trait LearningTaskParams extends Params { /** * Specify the learning task and the corresponding learning objective. 
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson, * multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:linear */ - val objective = new Param[String](this, "objective", "objective function used for training," + - s" options: {${LearningTaskParams.supportedObjective.mkString(",")}", + final val objective = new Param[String](this, "objective", "objective function used for " + + s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}", (value: String) => LearningTaskParams.supportedObjective.contains(value)) + final def getObjective: String = $(objective) + /** * the initial prediction score of all instances, global bias. default=0.5 */ - val baseScore = new DoubleParam(this, "base_score", "the initial prediction score of all" + + final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" + " instances, global bias") + final def getBaseScore: Double = $(baseScore) + /** * evaluation metrics for validation data, a default metric will be assigned according to * objective(rmse for regression, and error for classification, mean average precision for * ranking). options: rmse, mae, logloss, error, merror, mlogloss, auc, aucpr, ndcg, map, * gamma-deviance */ - val evalMetric = new Param[String](this, "eval_metric", "evaluation metrics for validation" + - " data, a default metric will be assigned according to objective (rmse for regression, and" + - " error for classification, mean average precision for ranking), options: " + - s" {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}", + final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " + + "validation data, a default metric will be assigned according to objective " + + "(rmse for regression, and error for classification, mean average precision for ranking), " + + s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}", (value: String) => LearningTaskParams.supportedEvalMetrics.contains(value)) + final def getEvalMetric: String = $(evalMetric) + /** * group data specify each group sizes for ranking task. To correspond to partition of * training data, it is nested. */ - val groupData = new GroupDataParam(this, "groupData", "group data specify each group size" + - " for ranking task. To correspond to partition of training data, it is nested.") - - /** - * Initial prediction (aka base margin) column name. - */ - val baseMarginCol = new Param[String](this, "baseMarginCol", "base margin column name") - - /** - * Instance weights column name. - */ - val weightCol = new Param[String](this, "weightCol", "weight column name") + final val groupData = new GroupDataParam(this, "groupData", "group data specify each group " + + "size for ranking task. To correspond to partition of training data, it is nested.") /** * Fraction of training points to use for testing. */ - val trainTestRatio = new DoubleParam(this, "trainTestRatio", + final val trainTestRatio = new DoubleParam(this, "trainTestRatio", "fraction of training points to use for testing", ParamValidators.inRange(0, 1)) + final def getTrainTestRatio: Double = $(trainTestRatio) + /** * If non-zero, the training will be stopped after a specified number * of consecutive increases in any evaluation metric. 
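 * For example, numEarlyStoppingRounds = 5 halts training once the evaluation metric
 * has worsened for five consecutive rounds; 0 (the default) disables early stopping.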
*/ - val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", + final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", "number of rounds of decreasing eval metric to tolerate before " + "stopping the training", (value: Int) => value == 0 || value > 1) - setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null, - baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0, - numEarlyStoppingRounds -> 0) + final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds) + + setDefault(objective -> "reg:linear", baseScore -> 0.5, groupData -> null, + trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0) } private[spark] object LearningTaskParams { diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train deleted file mode 100644 index 1f31343dd..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train +++ /dev/null @@ -1,75 +0,0 @@ -0 1:985.574005058 2:320.223538037 3:0.621236086198 -0 1:1010.52917943 2:635.535543082 3:2.14984030531 -0 1:1012.91900422 2:132.387300057 3:0.488761066665 -0 1:990.829194034 2:135.102081162 3:0.747701610673 -0 1:1007.05103629 2:154.289183562 3:0.464118249201 -0 1:994.9573036 2:317.483732878 3:0.0313685555674 -0 1:987.8071541 2:731.349178363 3:0.244616944245 -1 1:10.0349544469 2:2.29750906143 3:36.4949974282 -0 1:9.92953881383 2:5.39134047297 3:120.041297548 -0 1:10.0909866713 2:9.06191026312 3:138.807825798 -1 1:10.2090970614 2:0.0784495944448 3:58.207703565 -0 1:9.85695905893 2:9.99500727713 3:56.8610243778 -1 1:10.0805758547 2:0.0410805760559 3:222.102302076 -0 1:10.1209914486 2:9.9729127088 3:171.888238763 -0 1:10.0331939798 2:0.853339303793 3:311.181328375 -0 1:9.93901762951 2:2.72757449146 3:78.4859514413 -0 1:10.0752365346 2:9.18695328235 3:49.8520256553 -1 1:10.0456548902 2:0.270936043122 3:123.462958597 -0 1:10.0568923673 2:0.82997113263 3:44.9391426001 -0 1:9.8214143472 2:0.277538931578 3:15.4217659578 -0 1:9.95258604431 2:8.69564346094 3:255.513470671 -0 1:9.91934976357 2:7.72809741413 3:82.171591817 -0 1:10.043239582 2:8.64168255553 3:38.9657919329 -1 1:10.0236147929 2:0.0496662263659 3:4.40889812286 -1 1:1001.85585324 2:3.75646886071 3:0.0179224994842 -0 1:1014.25578571 2:0.285765311201 3:0.510329864983 -1 1:1002.81422786 2:9.77676280375 3:0.433705951912 -1 1:998.072711553 2:2.82100686538 3:0.889829076909 -0 1:1003.77395036 2:2.55916592114 3:0.0359402151496 -1 1:10.0807877782 2:4.98513959013 3:47.5266363559 -0 1:10.0015013081 2:9.94302478763 3:78.3697486277 -1 1:10.0441936789 2:0.305091816635 3:56.8213984987 -0 1:9.94257106618 2:7.23909568913 3:442.463339039 -1 1:9.86479307916 2:6.41701315844 3:55.1365304834 -0 1:10.0428628516 2:9.98466447697 3:0.391632812588 -0 1:9.94445884566 2:9.99970945878 3:260.438436534 -1 1:9.84641392823 2:225.78051312 3:1.00525978847 -1 1:9.86907690608 2:26.8971083147 3:0.577959255991 -0 1:10.0177314626 2:0.110585342313 3:2.30545043031 -0 1:10.0688190907 2:412.023866234 3:1.22421542264 -0 1:10.1251769646 2:13.8212202925 3:0.129171734504 -0 1:10.0840758802 2:407.359097187 3:0.477000870705 -0 1:10.1007458705 2:987.183625145 3:0.149385677415 -0 1:9.86472656059 2:169.559640615 3:0.147221652519 -0 1:9.94207419238 2:507.290053755 3:0.41996207214 -0 1:9.9671005502 2:1.62610457716 3:0.408173666788 -0 1:1010.57126596 2:9.06673707562 3:0.672092284372 -0 1:1001.6718262 2:9.53203990055 3:4.7364050044 -0 1:995.777341384 
2:4.43847316256 3:2.07229073634 -0 1:1002.95701386 2:5.51711016665 3:1.24294450546 -0 1:1016.0988238 2:0.626468941906 3:0.105627919134 -0 1:1013.67571419 2:0.042315529666 3:0.717619310322 -1 1:994.747747892 2:6.01989364024 3:0.772910130015 -1 1:991.654593872 2:7.35575736952 3:1.19822091548 -0 1:1008.47101732 2:8.28240754909 3:0.229582481359 -0 1:1000.81975227 2:1.52448354056 3:0.096441660362 -0 1:10.0900922344 2:322.656649307 3:57.8149073088 -1 1:10.0868337371 2:2.88652339174 3:54.8865514572 -0 1:10.0988984137 2:979.483832657 3:52.6809830901 -0 1:9.97678959238 2:665.770979738 3:481.069628909 -0 1:9.78554312773 2:257.309358658 3:47.7324475232 -0 1:10.0985967566 2:935.896512941 3:138.937052808 -0 1:10.0522252319 2:876.376299607 3:6.00373510669 -1 1:9.88065229501 2:9.99979825653 3:0.0674603696149 -0 1:10.0483244098 2:0.0653852316381 3:0.130679349938 -1 1:9.99685215607 2:1.76602542774 3:0.2551321159 -0 1:9.99750159428 2:1.01591534436 3:0.145445506504 -1 1:9.97380908941 2:0.940048645571 3:0.411805696316 -0 1:9.99977678382 2:6.91329929641 3:5.57858201258 -0 1:978.876096381 2:933.775364741 3:0.579170824236 -0 1:998.381016406 2:220.940470582 3:2.01491778565 -0 1:987.917644594 2:8.74667873567 3:0.364006099758 -0 1:1000.20994892 2:25.2945450565 3:3.5684398964 -0 1:1014.57141264 2:675.593540733 3:0.164174055535 -0 1:998.867283535 2:765.452750642 3:0.818425293238 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group deleted file mode 100644 index 67e55b03b..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-0.txt.train.group +++ /dev/null @@ -1,10 +0,0 @@ -7 -7 -10 -5 -7 -10 -10 -7 -6 -6 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train deleted file mode 100644 index 44c0f1ae3..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train +++ /dev/null @@ -1,74 +0,0 @@ -0 1:10.2143092481 2:273.576539531 3:137.111774354 -0 1:10.0366658918 2:842.469052609 3:2.32134375927 -0 1:10.1281202091 2:395.654057342 3:35.4184893063 -0 1:10.1443721289 2:960.058461049 3:272.887070637 -0 1:10.1353234784 2:535.51304462 3:2.15393842032 -1 1:10.0451640374 2:216.733858424 3:55.6533298016 -1 1:9.94254592171 2:44.5985537358 3:304.614176871 -0 1:10.1319257181 2:613.545504487 3:5.42391587912 -0 1:1020.63622468 2:997.476744201 3:0.509425590461 -0 1:986.304585519 2:822.669937965 3:0.605133561808 -1 1:1012.66863221 2:26.7185759069 3:0.0875458784828 -0 1:995.387656321 2:81.8540176995 3:0.691999430068 -0 1:1020.6587198 2:848.826964547 3:0.540159430526 -1 1:1003.81573853 2:379.84350931 3:0.0083682925194 -0 1:1021.60921516 2:641.376951467 3:1.12339054807 -0 1:1000.17585041 2:122.107138713 3:1.09906375372 -1 1:987.64802348 2:5.98448541152 3:0.124241987204 -1 1:9.94610136583 2:346.114985897 3:0.387708236565 -0 1:9.96812192337 2:313.278109696 3:0.00863026595671 -0 1:10.0181739194 2:36.7378924562 3:2.92179879835 -0 1:9.89000102695 2:164.273723971 3:0.685222591968 -0 1:10.1555212436 2:320.451459462 3:2.01341536261 -0 1:10.0085727613 2:999.767117646 3:0.462294934168 -1 1:9.93099658724 2:5.17478203909 3:0.213855205032 -0 1:10.0629454957 2:663.088181857 3:0.049022351462 -0 1:10.1109732417 2:734.904569784 3:1.6998450094 -0 1:1006.6015266 2:505.023453703 3:1.90870566777 -0 1:991.865769489 2:245.437343115 3:0.475109744256 -0 1:998.682734072 2:950.041057232 3:1.9256314201 -0 
1:1005.02207209 2:2.9619314197 3:0.0517146822357 -0 1:1002.54526214 2:860.562681899 3:0.915687092848 -0 1:1000.38847359 2:808.416525088 3:0.209690673808 -1 1:992.557818382 2:373.889409453 3:0.107571728577 -0 1:1002.07722137 2:997.329626371 3:1.06504260496 -0 1:1000.40504333 2:949.832139189 3:0.539159980327 -0 1:10.1460179902 2:8.86082969819 3:135.953842715 -1 1:9.98529296553 2:2.87366448495 3:1.74249892194 -0 1:9.88942676744 2:9.4031821056 3:149.473066381 -1 1:10.0192953341 2:1.99685737576 3:1.79502473397 -0 1:10.0110654379 2:8.13112593726 3:87.7765628103 -0 1:997.148677047 2:733.936190093 3:1.49298494242 -0 1:1008.70465919 2:957.121652078 3:0.217414013634 -1 1:997.356154278 2:541.599587807 3:0.100855972216 -0 1:999.615897283 2:943.700501824 3:0.862874175879 -1 1:997.36859077 2:0.200859940848 3:0.13601892182 -0 1:10.0423255624 2:1.73855202168 3:0.956695338485 -1 1:9.88440755486 2:9.9994600678 3:0.305080529665 -0 1:10.0891026412 2:3.28031719474 3:0.364450973697 -0 1:9.90078644258 2:8.77839663617 3:0.456660574479 -1 1:9.79380029711 2:8.77220326156 3:0.527292005175 -0 1:9.93613887011 2:9.76270841268 3:1.40865693823 -0 1:10.0009239007 2:7.29056178263 3:0.498015866607 -0 1:9.96603319905 2:5.12498000925 3:0.517492532783 -0 1:10.0923827222 2:2.76652583955 3:1.56571226159 -1 1:10.0983782035 2:587.788120694 3:0.031756483687 -1 1:9.91397225464 2:994.527496819 3:3.72092164978 -0 1:10.1057472738 2:2.92894440088 3:0.683506438532 -0 1:10.1014053354 2:959.082038017 3:1.07039624129 -0 1:10.1433253044 2:322.515119317 3:0.51408278993 -1 1:9.82832510699 2:637.104433908 3:0.250272776427 -0 1:1000.49729075 2:2.75336888111 3:0.576634423274 -1 1:984.90338088 2:0.0295435794035 3:1.26273339929 -0 1:1001.53811442 2:4.64164410861 3:0.0293389959504 -1 1:995.875898395 2:5.08223403205 3:0.382330566779 -0 1:996.405937252 2:6.26395190757 3:0.453645816611 -0 1:10.0165140779 2:340.126072514 3:0.220794603312 -0 1:9.93482824816 2:951.672000448 3:0.124406293612 -0 1:10.1700278554 2:0.0140985961008 3:0.252452256311 -0 1:9.99825079542 2:950.382643896 3:0.875382402062 -0 1:9.87316410028 2:686.788257829 3:0.215886999825 -0 1:10.2893240654 2:89.3947931451 3:0.569578232133 -0 1:9.98689192703 2:0.430107535413 3:2.99869831728 -0 1:10.1365175107 2:972.279245093 3:0.0865099386744 -0 1:9.90744703306 2:50.810461183 3:3.00863325197 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group deleted file mode 100644 index 877ef9231..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo-1.txt.train.group +++ /dev/null @@ -1,10 +0,0 @@ -8 -9 -9 -9 -5 -5 -9 -6 -5 -9 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group b/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group deleted file mode 100644 index 81e3e05be..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test.group +++ /dev/null @@ -1,10 +0,0 @@ -7 -5 -9 -6 -6 -8 -7 -6 -5 -7 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank.test.csv b/jvm-packages/xgboost4j-spark/src/test/resources/rank.test.csv new file mode 100644 index 000000000..83bf8b080 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank.test.csv @@ -0,0 +1,66 @@ +0,10.0229017899,7.30178495562,0.118115020017,1 +0,9.93639621859,9.93102159291,0.0435030004396,1 +0,10.1301737265,0.00411765220572,2.4165878053,1 +1,9.87828587087,0.608588414992,0.111262590883,1 
+0,10.1373430048,0.47764012225,0.991553052194,1 +0,10.0523814718,4.72152505167,0.672978832666,1 +0,10.0449715742,8.40373928536,0.384457573667,1 +1,996.398498791,941.976309154,0.230269231292,2 +0,1005.11269468,900.093680877,0.265031528873,2 +0,997.160349441,891.331101688,2.19362017313,2 +0,993.754139031,44.8000165317,1.03868009875,2 +1,994.831299184,241.959208453,0.667631827024,2 +0,995.948333283,7.94326917112,0.750490877118,3 +0,989.733981273,7.52077625436,0.0126335967282,3 +0,1003.54086516,6.48177510564,1.19441696788,3 +0,996.56177804,9.71959812613,1.33082465111,3 +0,1005.61382467,0.234339369309,1.17987797356,3 +1,980.215758708,6.85554542926,2.63965085259,3 +1,987.776408872,2.23354609991,0.841885278028,3 +0,1006.54260396,8.12142049834,2.26639471174,3 +0,1009.87927639,6.40028519044,0.775155669615,3 +0,9.95006244393,928.76896718,234.948458244,4 +1,10.0749152258,255.294574476,62.9728604166,4 +1,10.1916541988,312.682867085,92.299413677,4 +0,9.95646724484,742.263188416,53.3310473654,4 +0,9.86211293222,996.237023866,2.00760301168,4 +1,9.91801019468,303.971783709,50.3147230679,4 +0,996.983996934,9.52188222766,1.33588120981,5 +0,995.704388126,9.49260524915,0.908498516541,5 +0,987.86480767,0.0870786716821,0.108859297837,5 +0,1000.99561307,2.85272694575,0.171134518956,5 +0,1011.05508066,7.55336771768,1.04950084825,5 +1,985.52199365,0.763305780608,1.7402424375,5 +0,10.0430321467,813.185427181,4.97728254185,6 +0,10.0812334228,258.297288417,0.127477670549,6 +0,9.84210504292,887.205815261,0.991689193955,6 +1,9.94625332613,0.298622762132,0.147881353231,6 +0,9.97800659954,727.619819757,0.0718361141866,6 +1,9.8037938472,957.385549617,0.0618862028941,6 +0,10.0880634741,185.024638577,1.7028095095,6 +0,9.98630799154,109.10631473,0.681117359751,6 +0,9.91671416638,166.248076588,122.538291094,7 +0,10.1206910464,88.1539468531,141.189859069,7 +1,10.1767160518,1.02960996847,172.02256237,7 +0,9.93025147233,391.196641942,58.040338247,7 +0,9.84850936037,474.63346537,17.5627875397,7 +1,9.8162731343,61.9199554213,30.6740972851,7 +0,10.0403482984,987.50416929,73.0472906209,7 +1,997.019228359,133.294717663,0.0572254083186,8 +0,973.303999107,1.79080888849,0.100478717048,8 +0,1008.28808825,342.282350685,0.409806485495,8 +0,1014.55621524,0.680510407082,0.929530602495,8 +1,1012.74370325,823.105266455,0.0894693730585,8 +0,1003.63554038,727.334432075,0.58206275756,8 +0,10.1560432436,740.35938307,11.6823378533,9 +0,9.83949099701,512.828227154,138.206666681,9 +1,10.1837395682,179.287126088,185.479062365,9 +1,9.9761881495,12.1093388336,9.1264604171,9 +1,9.77402180766,318.561317743,80.6005221355,9 +0,1011.15705381,0.215825852155,1.34429667906,10 +0,1005.60353229,727.202346126,1.47146041005,10 +1,1013.93702961,58.7312725205,0.421041560754,10 +0,1004.86813074,757.693204258,0.566055205344,10 +0,999.996324692,813.12386828,0.864428279513,10 +0,996.55255931,918.760056995,0.43365051974,10 +1,1004.1394132,464.371823646,0.312492288321,10 diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test b/jvm-packages/xgboost4j-spark/src/test/resources/rank.test.txt similarity index 100% rename from jvm-packages/xgboost4j-spark/src/test/resources/rank-demo.txt.test rename to jvm-packages/xgboost4j-spark/src/test/resources/rank.test.txt diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv b/jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv new file mode 100644 index 000000000..ebe232b51 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv @@ -0,0 +1,149 @@ 
+0,985.574005058,320.223538037,0.621236086198,1 +0,1010.52917943,635.535543082,2.14984030531,1 +0,1012.91900422,132.387300057,0.488761066665,1 +0,990.829194034,135.102081162,0.747701610673,1 +0,1007.05103629,154.289183562,0.464118249201,1 +0,994.9573036,317.483732878,0.0313685555674,1 +0,987.8071541,731.349178363,0.244616944245,1 +1,10.0349544469,2.29750906143,36.4949974282,2 +0,9.92953881383,5.39134047297,120.041297548,2 +0,10.0909866713,9.06191026312,138.807825798,2 +1,10.2090970614,0.0784495944448,58.207703565,2 +0,9.85695905893,9.99500727713,56.8610243778,2 +1,10.0805758547,0.0410805760559,222.102302076,2 +0,10.1209914486,9.9729127088,171.888238763,2 +0,10.0331939798,0.853339303793,311.181328375,3 +0,9.93901762951,2.72757449146,78.4859514413,3 +0,10.0752365346,9.18695328235,49.8520256553,3 +1,10.0456548902,0.270936043122,123.462958597,3 +0,10.0568923673,0.82997113263,44.9391426001,3 +0,9.8214143472,0.277538931578,15.4217659578,3 +0,9.95258604431,8.69564346094,255.513470671,3 +0,9.91934976357,7.72809741413,82.171591817,3 +0,10.043239582,8.64168255553,38.9657919329,3 +1,10.0236147929,0.0496662263659,4.40889812286,3 +1,1001.85585324,3.75646886071,0.0179224994842,4 +0,1014.25578571,0.285765311201,0.510329864983,4 +1,1002.81422786,9.77676280375,0.433705951912,4 +1,998.072711553,2.82100686538,0.889829076909,4 +0,1003.77395036,2.55916592114,0.0359402151496,4 +1,10.0807877782,4.98513959013,47.5266363559,5 +0,10.0015013081,9.94302478763,78.3697486277,5 +1,10.0441936789,0.305091816635,56.8213984987,5 +0,9.94257106618,7.23909568913,442.463339039,5 +1,9.86479307916,6.41701315844,55.1365304834,5 +0,10.0428628516,9.98466447697,0.391632812588,5 +0,9.94445884566,9.99970945878,260.438436534,5 +1,9.84641392823,225.78051312,1.00525978847,6 +1,9.86907690608,26.8971083147,0.577959255991,6 +0,10.0177314626,0.110585342313,2.30545043031,6 +0,10.0688190907,412.023866234,1.22421542264,6 +0,10.1251769646,13.8212202925,0.129171734504,6 +0,10.0840758802,407.359097187,0.477000870705,6 +0,10.1007458705,987.183625145,0.149385677415,6 +0,9.86472656059,169.559640615,0.147221652519,6 +0,9.94207419238,507.290053755,0.41996207214,6 +0,9.9671005502,1.62610457716,0.408173666788,6 +0,1010.57126596,9.06673707562,0.672092284372,7 +0,1001.6718262,9.53203990055,4.7364050044,7 +0,995.777341384,4.43847316256,2.07229073634,7 +0,1002.95701386,5.51711016665,1.24294450546,7 +0,1016.0988238,0.626468941906,0.105627919134,7 +0,1013.67571419,0.042315529666,0.717619310322,7 +1,994.747747892,6.01989364024,0.772910130015,7 +1,991.654593872,7.35575736952,1.19822091548,7 +0,1008.47101732,8.28240754909,0.229582481359,7 +0,1000.81975227,1.52448354056,0.096441660362,7 +0,10.0900922344,322.656649307,57.8149073088,8 +1,10.0868337371,2.88652339174,54.8865514572,8 +0,10.0988984137,979.483832657,52.6809830901,8 +0,9.97678959238,665.770979738,481.069628909,8 +0,9.78554312773,257.309358658,47.7324475232,8 +0,10.0985967566,935.896512941,138.937052808,8 +0,10.0522252319,876.376299607,6.00373510669,8 +1,9.88065229501,9.99979825653,0.0674603696149,9 +0,10.0483244098,0.0653852316381,0.130679349938,9 +1,9.99685215607,1.76602542774,0.2551321159,9 +0,9.99750159428,1.01591534436,0.145445506504,9 +1,9.97380908941,0.940048645571,0.411805696316,9 +0,9.99977678382,6.91329929641,5.57858201258,9 +0,978.876096381,933.775364741,0.579170824236,10 +0,998.381016406,220.940470582,2.01491778565,10 +0,987.917644594,8.74667873567,0.364006099758,10 +0,1000.20994892,25.2945450565,3.5684398964,10 +0,1014.57141264,675.593540733,0.164174055535,10 
+0,998.867283535,765.452750642,0.818425293238,10 +0,10.2143092481,273.576539531,137.111774354,11 +0,10.0366658918,842.469052609,2.32134375927,11 +0,10.1281202091,395.654057342,35.4184893063,11 +0,10.1443721289,960.058461049,272.887070637,11 +0,10.1353234784,535.51304462,2.15393842032,11 +1,10.0451640374,216.733858424,55.6533298016,11 +1,9.94254592171,44.5985537358,304.614176871,11 +0,10.1319257181,613.545504487,5.42391587912,11 +0,1020.63622468,997.476744201,0.509425590461,12 +0,986.304585519,822.669937965,0.605133561808,12 +1,1012.66863221,26.7185759069,0.0875458784828,12 +0,995.387656321,81.8540176995,0.691999430068,12 +0,1020.6587198,848.826964547,0.540159430526,12 +1,1003.81573853,379.84350931,0.0083682925194,12 +0,1021.60921516,641.376951467,1.12339054807,12 +0,1000.17585041,122.107138713,1.09906375372,12 +1,987.64802348,5.98448541152,0.124241987204,12 +1,9.94610136583,346.114985897,0.387708236565,13 +0,9.96812192337,313.278109696,0.00863026595671,13 +0,10.0181739194,36.7378924562,2.92179879835,13 +0,9.89000102695,164.273723971,0.685222591968,13 +0,10.1555212436,320.451459462,2.01341536261,13 +0,10.0085727613,999.767117646,0.462294934168,13 +1,9.93099658724,5.17478203909,0.213855205032,13 +0,10.0629454957,663.088181857,0.049022351462,13 +0,10.1109732417,734.904569784,1.6998450094,13 +0,1006.6015266,505.023453703,1.90870566777,14 +0,991.865769489,245.437343115,0.475109744256,14 +0,998.682734072,950.041057232,1.9256314201,14 +0,1005.02207209,2.9619314197,0.0517146822357,14 +0,1002.54526214,860.562681899,0.915687092848,14 +0,1000.38847359,808.416525088,0.209690673808,14 +1,992.557818382,373.889409453,0.107571728577,14 +0,1002.07722137,997.329626371,1.06504260496,14 +0,1000.40504333,949.832139189,0.539159980327,14 +0,10.1460179902,8.86082969819,135.953842715,15 +1,9.98529296553,2.87366448495,1.74249892194,15 +0,9.88942676744,9.4031821056,149.473066381,15 +1,10.0192953341,1.99685737576,1.79502473397,15 +0,10.0110654379,8.13112593726,87.7765628103,15 +0,997.148677047,733.936190093,1.49298494242,16 +0,1008.70465919,957.121652078,0.217414013634,16 +1,997.356154278,541.599587807,0.100855972216,16 +0,999.615897283,943.700501824,0.862874175879,16 +1,997.36859077,0.200859940848,0.13601892182,16 +0,10.0423255624,1.73855202168,0.956695338485,17 +1,9.88440755486,9.9994600678,0.305080529665,17 +0,10.0891026412,3.28031719474,0.364450973697,17 +0,9.90078644258,8.77839663617,0.456660574479,17 +1,9.79380029711,8.77220326156,0.527292005175,17 +0,9.93613887011,9.76270841268,1.40865693823,17 +0,10.0009239007,7.29056178263,0.498015866607,17 +0,9.96603319905,5.12498000925,0.517492532783,17 +0,10.0923827222,2.76652583955,1.56571226159,17 +1,10.0983782035,587.788120694,0.031756483687,18 +1,9.91397225464,994.527496819,3.72092164978,18 +0,10.1057472738,2.92894440088,0.683506438532,18 +0,10.1014053354,959.082038017,1.07039624129,18 +0,10.1433253044,322.515119317,0.51408278993,18 +1,9.82832510699,637.104433908,0.250272776427,18 +0,1000.49729075,2.75336888111,0.576634423274,19 +1,984.90338088,0.0295435794035,1.26273339929,19 +0,1001.53811442,4.64164410861,0.0293389959504,19 +1,995.875898395,5.08223403205,0.382330566779,19 +0,996.405937252,6.26395190757,0.453645816611,19 +0,10.0165140779,340.126072514,0.220794603312,20 +0,9.93482824816,951.672000448,0.124406293612,20 +0,10.1700278554,0.0140985961008,0.252452256311,20 +0,9.99825079542,950.382643896,0.875382402062,20 +0,9.87316410028,686.788257829,0.215886999825,20 +0,10.2893240654,89.3947931451,0.569578232133,20 +0,9.98689192703,0.430107535413,2.99869831728,20 
+0,10.1365175107,972.279245093,0.0865099386744,20 +0,9.90744703306,50.810461183,3.00863325197,20 diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala index af012c4f5..5e148c2cf 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala @@ -21,37 +21,27 @@ import java.nio.file.Files import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkConf, SparkContext} -class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll { - var sc: SparkContext = _ - - override def beforeAll(): Unit = { - val conf: SparkConf = new SparkConf() - .setMaster("local[*]") - .setAppName("XGBoostSuite") - sc = new SparkContext(conf) - } +class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll { private lazy val (model4, model8) = { - import DataUtils._ - val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache() + val training = buildDataFrame(Classification.train) val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic") - (XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, nWorkers = sc.defaultParallelism), - XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, nWorkers = sc.defaultParallelism)) + "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism) + (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training), + new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training)) } test("test update/load models") { val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString val manager = new CheckpointManager(sc, tmpPath) - manager.updateCheckpoint(model4) + manager.updateCheckpoint(model4._booster) var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "4.model") assert(manager.loadCheckpointAsBooster.booster.getVersion == 4) - manager.updateCheckpoint(model8) + manager.updateCheckpoint(model8._booster) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "8.model") @@ -61,7 +51,7 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll { test("test cleanUpHigherVersions") { val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString val manager = new CheckpointManager(sc, tmpPath) - manager.updateCheckpoint(model8) + manager.updateCheckpoint(model8._booster) manager.cleanUpHigherVersions(round = 8) assert(new File(s"$tmpPath/8.model").exists()) @@ -74,7 +64,8 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll { val manager = new CheckpointManager(sc, tmpPath) assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7)) assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7)) - manager.updateCheckpoint(model4) + manager.updateCheckpoint(model4._booster) assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7)) } + } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala 
b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala index 192fba957..9e617aecf 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala @@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark import java.io.File +import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.spark.SparkContext -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.scalatest.{BeforeAndAfterEach, FunSuite} trait PerTest extends BeforeAndAfterEach { self: FunSuite => + protected val numWorkers: Int = Runtime.getRuntime.availableProcessors() @transient private var currentSession: SparkSession = _ @@ -62,4 +64,30 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite => file.delete() } } + + protected def buildDataFrame( + labeledPoints: Seq[XGBLabeledPoint], + numPartitions: Int = numWorkers): DataFrame = { + import DataUtils._ + val it = labeledPoints.iterator.zipWithIndex + .map { case (labeledPoint: XGBLabeledPoint, id: Int) => + (id, labeledPoint.label, labeledPoint.features) + } + + ss.createDataFrame(sc.parallelize(it.toList, numPartitions)) + .toDF("id", "label", "features") + } + + protected def buildDataFrameWithGroup( + labeledPoints: Seq[XGBLabeledPoint], + numPartitions: Int = numWorkers): DataFrame = { + import DataUtils._ + val it = labeledPoints.iterator.zipWithIndex + .map { case (labeledPoint: XGBLabeledPoint, id: Int) => + (id, labeledPoint.label, labeledPoint.features, labeledPoint.group) + } + + ss.createDataFrame(sc.parallelize(it.toList, numPartitions)) + .toDF("id", "label", "features", "group") + } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala new file mode 100644 index 000000000..cd9ef8c0f --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala @@ -0,0 +1,167 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala.spark + +import java.io.{File, FileNotFoundException} +import java.util.Arrays + +import ml.dmlc.xgboost4j.scala.DMatrix + +import scala.util.Random +import org.apache.spark.ml.feature._ +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.network.util.JavaUtils +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll { + + private var tempDir: File = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + tempDir = new File(System.getProperty("java.io.tmpdir"), this.getClass.getName) + if (tempDir.exists) { + tempDir.delete + } + tempDir.mkdirs + } + + override def afterAll(): Unit = { + JavaUtils.deleteRecursively(tempDir) + super.afterAll() + } + + private def delete(f: File) { + if (f.exists) { + if (f.isDirectory) { + for (c <- f.listFiles) { + delete(c) + } + } + if (!f.delete) { + throw new FileNotFoundException("Failed to delete file: " + f) + } + } + } + + test("test persistence of XGBoostClassifier and XGBoostClassificationModel") { + val eval = new EvalError() + val trainingDF = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers) + val xgbc = new XGBoostClassifier(paramMap) + val xgbcPath = new File(tempDir, "xgbc").getPath + xgbc.write.overwrite().save(xgbcPath) + val xgbc2 = XGBoostClassifier.load(xgbcPath) + val paramMap2 = xgbc2.MLlib2XGBoostParams + paramMap.foreach { + case (k, v) => assert(v.toString == paramMap2(k).toString) + } + + val model = xgbc.fit(trainingDF) + val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults < 0.1) + val xgbcModelPath = new File(tempDir, "xgbcModel").getPath + model.write.overwrite.save(xgbcModelPath) + val model2 = XGBoostClassificationModel.load(xgbcModelPath) + assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray)) + + assert(model.getEta === model2.getEta) + assert(model.getNumRound === model2.getNumRound) + assert(model.getRawPredictionCol === model2.getRawPredictionCol) + val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults === evalResults2) + } + + test("test persistence of XGBoostRegressor and XGBoostRegressionModel") { + val eval = new EvalError() + val trainingDF = buildDataFrame(Regression.train) + val testDM = new DMatrix(Regression.test.iterator) + + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> "10", "num_workers" -> numWorkers) + val xgbr = new XGBoostRegressor(paramMap) + val xgbrPath = new File(tempDir, "xgbr").getPath + xgbr.write.overwrite().save(xgbrPath) + val xgbr2 = XGBoostRegressor.load(xgbrPath) + val paramMap2 = xgbr2.MLlib2XGBoostParams + paramMap.foreach { + case (k, v) => assert(v.toString == paramMap2(k).toString) + } + + val model = xgbr.fit(trainingDF) + val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults < 0.1) + val xgbrModelPath = new File(tempDir, "xgbrModel").getPath + model.write.overwrite.save(xgbrModelPath) + val model2 = XGBoostRegressionModel.load(xgbrModelPath) + assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray)) + + assert(model.getEta === model2.getEta) + 
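+      // Editor's sketch (hedged): the assertions around this point verify the
+      // user-level round trip below; the path and params are illustrative, and
+      // the classes assumed are the ones introduced in this patch.
+      //   val fitted = new XGBoostRegressor(Map("objective" -> "reg:linear",
+      //     "num_round" -> 10, "num_workers" -> 2)).fit(trainingDF)
+      //   fitted.write.overwrite().save("/tmp/xgbrModel")
+      //   val restored = XGBoostRegressionModel.load("/tmp/xgbrModel")
+      //   // restored should carry both the Params and a byte-identical booster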
assert(model.getNumRound === model2.getNumRound) + assert(model.getPredictionCol === model2.getPredictionCol) + val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM) + assert(evalResults === evalResults2) + } + + test("test persistence of MLlib pipeline with XGBoostClassificationModel") { + + val r = new Random(0) + // maybe move to shared context, but requires session to import implicits + val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))). + toDF("feature", "label") + + val assembler = new VectorAssembler() + .setInputCols(df.columns.filter(!_.contains("label"))) + .setOutputCol("features") + + val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers, + "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")) + val xgb = new XGBoostClassifier(paramMap) + + // Construct MLlib pipeline, save and load + val pipeline = new Pipeline().setStages(Array(assembler, xgb)) + val pipePath = new File(tempDir, "pipeline").getPath + pipeline.write.overwrite().save(pipePath) + val pipeline2 = Pipeline.read.load(pipePath) + val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier] + val paramMap2 = xgb2.MLlib2XGBoostParams + paramMap.foreach { + case (k, v) => assert(v.toString == paramMap2(k).toString) + } + + // Model training, save and load + val pipeModel = pipeline.fit(df) + val pipeModelPath = new File(tempDir, "pipelineModel").getPath + pipeModel.write.overwrite.save(pipeModelPath) + val pipeModel2 = PipelineModel.load(pipeModelPath) + + val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel] + val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel] + + assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray)) + + assert(xgbModel.getEta === xgbModel2.getEta) + assert(xgbModel.getNumRound === xgbModel2.getNumRound) + assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol) + } +} + diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala index a33443614..2a93170a6 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala @@ -16,8 +16,8 @@ package ml.dmlc.xgboost4j.scala.spark +import scala.collection.mutable import scala.io.Source - import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} trait TrainTestData { @@ -48,6 +48,17 @@ trait TrainTestData { XGBLabeledPoint(label, null, values) }.toList } + + protected def getLabeledPointsWithGroup(resource: String): Seq[XGBLabeledPoint] = { + getResourceLines(resource).map { line => + val original = line.split(",") + val length = original.length + val label = original.head.toFloat + val group = original.last.toInt + val values = original.slice(1, length - 1).map(_.toFloat) + XGBLabeledPoint(label, null, values, 1f, group, Float.NaN) + }.toList + } } object Classification extends TrainTestData { @@ -80,11 +91,8 @@ object Regression extends TrainTestData { } object Ranking extends TrainTestData { - val train0: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-0.txt.train", zeroBased = false) - val train1: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-1.txt.train", zeroBased = false) - val trainGroup0: Seq[Int] = 
getGroups("/rank-demo-0.txt.train.group") - val trainGroup1: Seq[Int] = getGroups("/rank-demo-1.txt.train.group") - val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo.txt.test", zeroBased = false) + val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv") + val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false) private def getGroups(resource: String): Seq[Int] = { getResourceLines(resource).map(_.toInt).toList diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala new file mode 100644 index 000000000..4bb9e2c8c --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -0,0 +1,207 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql._ +import org.scalatest.FunSuite + +class XGBoostClassifierSuite extends FunSuite with PerTest { + + test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") { + val trainingDM = new DMatrix(Classification.train.iterator) + val testDM = new DMatrix(Classification.test.iterator) + val trainingDF = buildDataFrame(Classification.train) + val testDF = buildDataFrame(Classification.test) + val round = 5 + + val paramMap = Map( + "eta" -> "1", + "max_depth" -> "6", + "silent" -> "1", + "objective" -> "binary:logistic") + + val model1 = ScalaXGBoost.train(trainingDM, paramMap, round) + val prediction1 = model1.predict(testDM) + + val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round, + "num_workers" -> numWorkers)).fit(trainingDF) + + val prediction2 = model2.transform(testDF). + collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap + + assert(testDF.count() === prediction2.size) + // the vector length in probability column is 2 since we have to fit to the evaluator in Spark + for (i <- prediction1.indices) { + assert(prediction1(i).length === prediction2(i).values.length - 1) + for (j <- prediction1(i).indices) { + assert(prediction1(i)(j) === prediction2(i)(j + 1)) + } + } + + val prediction3 = model1.predict(testDM, outPutMargin = true) + val prediction4 = model2.transform(testDF). 
+ collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap + + assert(testDF.count() === prediction4.size) + for (i <- prediction3.indices) { + assert(prediction3(i).length === prediction4(i).values.length) + for (j <- prediction3(i).indices) { + assert(prediction3(i)(j) === prediction4(i)(j)) + } + } + } + + test("Set params in XGBoost and MLlib way should produce same model") { + val trainingDF = buildDataFrame(Classification.train) + val testDF = buildDataFrame(Classification.test) + val round = 5 + + val paramMap = Map( + "eta" -> "1", + "max_depth" -> "6", + "silent" -> "1", + "objective" -> "binary:logistic", + "num_round" -> round, + "num_workers" -> numWorkers) + + // Set params in XGBoost way + val model1 = new XGBoostClassifier(paramMap).fit(trainingDF) + // Set params in MLlib way + val model2 = new XGBoostClassifier() + .setEta(1) + .setMaxDepth(6) + .setSilent(1) + .setObjective("binary:logistic") + .setNumRound(round) + .setNumWorkers(numWorkers) + .fit(trainingDF) + + val prediction1 = model1.transform(testDF).select("prediction").collect() + val prediction2 = model2.transform(testDF).select("prediction").collect() + + prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) => + assert(p1 === p2) + } + } + + test("test schema of XGBoostClassificationModel") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers) + val trainingDF = buildDataFrame(Classification.train) + val testDF = buildDataFrame(Classification.test) + + val model = new XGBoostClassifier(paramMap).fit(trainingDF) + + model.setRawPredictionCol("raw_prediction") + .setProbabilityCol("probability_prediction") + .setPredictionCol("final_prediction") + var predictionDF = model.transform(testDF) + assert(predictionDF.columns.contains("id")) + assert(predictionDF.columns.contains("features")) + assert(predictionDF.columns.contains("label")) + assert(predictionDF.columns.contains("raw_prediction")) + assert(predictionDF.columns.contains("probability_prediction")) + assert(predictionDF.columns.contains("final_prediction")) + model.setRawPredictionCol("").setPredictionCol("final_prediction") + predictionDF = model.transform(testDF) + assert(predictionDF.columns.contains("raw_prediction") === false) + assert(predictionDF.columns.contains("final_prediction")) + model.setRawPredictionCol("raw_prediction").setPredictionCol("") + predictionDF = model.transform(testDF) + assert(predictionDF.columns.contains("raw_prediction")) + assert(predictionDF.columns.contains("final_prediction") === false) + + assert(model.summary.trainObjectiveHistory.length === 5) + assert(model.summary.testObjectiveHistory.isEmpty) + } + + test("XGBoost and Spark parameters synchronize correctly") { + val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic") + // from xgboost params to spark params + val xgb = new XGBoostClassifier(xgbParamMap) + assert(xgb.getEta === 1.0) + assert(xgb.getObjective === "binary:logistic") + // from spark to xgboost params + val xgbCopy = xgb.copy(ParamMap.empty) + assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0) + assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic") + val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss")) + assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss") + } + + test("multi class classification") { + val paramMap = Map("eta" -> "0.1", 
"max_depth" -> "6", "silent" -> "1", + "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5, + "num_workers" -> numWorkers) + val trainingDF = buildDataFrame(MultiClassification.train) + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(trainingDF) + assert(model.getEta == 0.1) + assert(model.getMaxDepth == 6) + assert(model.numClasses == 6) + } + + test("use base margin") { + val training1 = buildDataFrame(Classification.train) + val training2 = training1.withColumn("margin", functions.rand()) + val test = buildDataFrame(Classification.test) + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "test_train_split" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + + val xgb = new XGBoostClassifier(paramMap) + val model1 = xgb.fit(training1) + val model2 = xgb.setBaseMarginCol("margin").fit(training2) + val prediction1 = model1.transform(test).select(model1.getProbabilityCol) + .collect().map(row => row.getAs[Vector](0)) + val prediction2 = model2.transform(test).select(model2.getProbabilityCol) + .collect().map(row => row.getAs[Vector](0)) + var count = 0 + for ((r1, r2) <- prediction1.zip(prediction2)) { + if (!r1.equals(r2)) count = count + 1 + } + assert(count != 0) + } + + test("training summary") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers) + + val trainingDF = buildDataFrame(Classification.train) + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(trainingDF) + + assert(model.summary.trainObjectiveHistory.length === 5) + assert(model.summary.testObjectiveHistory.isEmpty) + } + + test("train/test split") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "train_test_ratio" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Classification.train) + + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(training) + val Some(testObjectiveHistory) = model.summary.testObjectiveHistory + assert(testObjectiveHistory.length === 5) + assert(model.summary.trainObjectiveHistory !== testObjectiveHistory) + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala index c216ab257..3b9eae707 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala @@ -17,36 +17,34 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} - -import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.scalatest.FunSuite class XGBoostConfigureSuite extends FunSuite with PerTest { + override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .config("spark.kryo.classesToRegister", classOf[Booster].getName) test("nthread configuration must be no larger than spark.task.cpus") { + val training = buildDataFrame(Classification.train) val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic", + 
"objective" -> "binary:logistic", "num_workers" -> numWorkers, "nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1)) intercept[IllegalArgumentException] { - XGBoost.trainWithRDD(sc.parallelize(List()), paramMap, 5, numWorkers) + new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training) } } test("kryoSerializer test") { - import DataUtils._ // TODO write an isolated test for Booster. - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator, null) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator, null) val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers) + + val model = new XGBoostClassifier(paramMap).fit(training) val eval = new EvalError() - assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) < 0.1) + assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala deleted file mode 100644 index 4cdcaaf39..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ /dev/null @@ -1,265 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.spark - -import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} -import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} -import org.apache.spark.ml.linalg.DenseVector -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DataTypes -import org.scalatest.FunSuite -import org.scalatest.prop.TableDrivenPropertyChecks - -class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks { - private def buildDataFrame( - labeledPoints: Seq[XGBLabeledPoint], - numPartitions: Int = numWorkers): DataFrame = { - import DataUtils._ - val it = labeledPoints.iterator.zipWithIndex - .map { case (labeledPoint: XGBLabeledPoint, id: Int) => - (id, labeledPoint.label, labeledPoint.features) - } - - ss.createDataFrame(sc.parallelize(it.toList, numPartitions)) - .toDF("id", "label", "features") - } - - test("test consistency and order preservation of dataframe-based model") { - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic") - val trainingItr = Classification.train.iterator - val testItr = Classification.test.iterator - val round = 5 - val trainDMatrix = new DMatrix(trainingItr) - val testDMatrix = new DMatrix(testItr) - val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round) - val predResultFromSeq = xgboostModel.predict(testDMatrix) - val trainingDF = buildDataFrame(Classification.train) - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = round, nWorkers = numWorkers) - val testDF = buildDataFrame(Classification.test) - val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF). 
- collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap - assert(testDF.count() === predResultsFromDF.size) - // the vector length in probabilties column is 2 since we have to fit to the evaluator in - // Spark - for (i <- predResultFromSeq.indices) { - assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1) - for (j <- predResultFromSeq(i).indices) { - assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1)) - } - } - } - - test("test transformLeaf") { - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic") - val trainingDF = buildDataFrame(Classification.train) - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 5, nWorkers = numWorkers) - val testDF = buildDataFrame(Classification.test) - xgBoostModelWithDF.transformLeaf(testDF).show() - } - - test("test schema of XGBoostRegressionModel") { - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "reg:linear") - val trainingDF = buildDataFrame(Regression.train) - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 5, nWorkers = numWorkers, useExternalMemory = true) - xgBoostModelWithDF.setPredictionCol("final_prediction") - val testDF = buildDataFrame(Regression.test) - val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF) - assert(predictionDF.columns.contains("id")) - assert(predictionDF.columns.contains("features")) - assert(predictionDF.columns.contains("label")) - assert(predictionDF.columns.contains("final_prediction")) - predictionDF.show() - } - - test("test schema of XGBoostClassificationModel") { - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic") - val trainingDF = buildDataFrame(Classification.train) - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 5, nWorkers = numWorkers, useExternalMemory = true) - xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol( - "raw_prediction").setPredictionCol("final_prediction") - val testDF = buildDataFrame(Classification.test) - var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF) - assert(predictionDF.columns.contains("id")) - assert(predictionDF.columns.contains("features")) - assert(predictionDF.columns.contains("label")) - assert(predictionDF.columns.contains("raw_prediction")) - assert(predictionDF.columns.contains("final_prediction")) - xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(""). - setPredictionCol("final_prediction") - predictionDF = xgBoostModelWithDF.transform(testDF) - assert(predictionDF.columns.contains("id")) - assert(predictionDF.columns.contains("features")) - assert(predictionDF.columns.contains("label")) - assert(predictionDF.columns.contains("raw_prediction") === false) - assert(predictionDF.columns.contains("final_prediction")) - xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel]. 
- setRawPredictionCol("raw_prediction").setPredictionCol("") - predictionDF = xgBoostModelWithDF.transform(testDF) - assert(predictionDF.columns.contains("id")) - assert(predictionDF.columns.contains("features")) - assert(predictionDF.columns.contains("label")) - assert(predictionDF.columns.contains("raw_prediction")) - assert(predictionDF.columns.contains("final_prediction") === false) - } - - test("xgboost and spark parameters synchronize correctly") { - val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic") - // from xgboost params to spark params - val xgbEstimator = new XGBoostEstimator(xgbParamMap) - assert(xgbEstimator.get(xgbEstimator.eta).get === 1.0) - assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic") - // from spark to xgboost params - val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty) - assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0) - assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic") - } - - test("eval_metric is configured correctly") { - val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic") - val xgbEstimator = new XGBoostEstimator(xgbParamMap) - assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error") - val sparkParamMap = ParamMap.empty - val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap) - assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error") - val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss")) - assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss") - } - - ignore("fast histogram algorithm parameters are exposed correctly") { - val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0", - "objective" -> "binary:logistic", "tree_method" -> "hist", - "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2", - "eval_metric" -> "error") - val testItr = Classification.test.iterator - val trainingDF = buildDataFrame(Classification.train) - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 10, nWorkers = math.min(2, numWorkers)) - val error = new EvalError - val testSetDMatrix = new DMatrix(testItr) - assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) < 0.1) - } - - test("multi_class classification test") { - val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "multi:softmax", "num_class" -> "6") - val trainingDF = buildDataFrame(MultiClassification.train) - XGBoost.trainWithDataFrame(trainingDF.toDF(), paramMap, round = 5, nWorkers = numWorkers) - } - - test("test DF use nested groupData") { - val trainingDF = buildDataFrame(Ranking.train0, 1) - .union(buildDataFrame(Ranking.train1, 1)) - val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1) - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "rank:pairwise", "groupData" -> trainGroupData) - - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 5, nWorkers = 2) - val testDF = buildDataFrame(Ranking.test) - val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF). 
- collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap - assert(testDF.count() === predResultsFromDF.size) - } - - test("params of estimator and produced model are coordinated correctly") { - val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "multi:softmax", "num_class" -> "6") - val trainingDF = buildDataFrame(MultiClassification.train) - val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers) - assert(model.get[Double](model.eta).get == 0.1) - assert(model.get[Int](model.maxDepth).get == 6) - assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6) - } - - test("test use base margin") { - import DataUtils._ - val trainingDf = buildDataFrame(Classification.train) - val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand()) - val testRDD = sc.parallelize(Classification.test.map(_.features)) - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic", "baseMarginCol" -> "margin", - "testTrainSplit" -> 0.5) - - def trainPredict(df: Dataset[_]): Array[Float] = { - XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers) - .predict(testRDD) - .map { case Array(p) => p } - .collect() - } - - val pred = trainPredict(trainingDf) - val predWithMargin = trainPredict(trainingDfWithMargin) - assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm }) - } - - test("test use weight") { - import DataUtils._ - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "reg:linear", "weightCol" -> "weight") - - val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType) - val trainingDF = buildDataFrame(Regression.train) - .withColumn("weight", getWeightFromId(col("id"))) - - val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, - nWorkers = numWorkers, useExternalMemory = true) - .setPredictionCol("final_prediction") - .setExternalMemory(true) - val testRDD = sc.parallelize(Regression.test.map(_.features)) - val predictions = model.predict(testRDD).collect().flatten - - // The predictions heavily relies on the first training instance, and thus are very close. 
- predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f)) - } - - test("training summary") { - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic").toMap - - val trainingDf = buildDataFrame(Classification.train) - val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5, - nWorkers = numWorkers) - - assert(model.summary.trainObjectiveHistory.length === 5) - assert(model.summary.testObjectiveHistory.isEmpty) - } - - test("train/test split") { - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic", "trainTestRatio" -> "0.5") - val trainingDf = buildDataFrame(Classification.train) - - forAll(Table("useExternalMemory", false, true)) { useExternalMemory => - val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5, - nWorkers = numWorkers, useExternalMemory = useExternalMemory) - val Some(testObjectiveHistory) = model.summary.testObjectiveHistory - assert(testObjectiveHistory.length === 5) - assert(model.summary.trainObjectiveHistory !== testObjectiveHistory) - } - } -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 64f8b8ca2..05e9ae757 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -18,19 +18,18 @@ package ml.dmlc.xgboost4j.scala.spark import java.nio.file.Files import java.util.concurrent.LinkedBlockingDeque - -import scala.util.Random import ml.dmlc.xgboost4j.java.Rabit import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.rabit.RabitTracker +import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext -import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} -import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector} -import org.apache.spark.rdd.RDD +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.sql._ import org.scalatest.FunSuite +import scala.util.Random class XGBoostGeneralSuite extends FunSuite with PerTest { + test("test Rabit allreduce to validate Scala-implemented Rabit tracker") { val vectorLength = 100 val rdd = sc.parallelize( @@ -87,283 +86,153 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { } test("training with external memory cache") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic").toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers, useExternalMemory = true) - assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) < 0.1) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers, + "use_external_memory" -> true) + val model = new XGBoostClassifier(paramMap).fit(training) + 
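+    // Editor's sketch (hedged): with "use_external_memory" -> true, each worker
+    // is expected to build its DMatrix against an on-disk cache rather than
+    // holding every data page in memory, conceptually via the two-argument
+    // constructor already used in these suites; the cache path is illustrative.
+    //   val dtrain = new DMatrix(partitionRowIter, "/tmp/xgb-dtrain-cache")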
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) } + test("training with Scala-implemented Rabit tracker") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic", - "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers) - assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) < 0.1) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers, + "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")) + val model = new XGBoostClassifier(paramMap).fit(training) + assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) } + ignore("test with fast histo depthwise") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic", "tree_method" -> "hist", - "grow_policy" -> "depthwise", "eval_metric" -> "error") + "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise", + "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2)) // TODO: histogram algorithm seems to be very very sensitive to worker number - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = math.min(numWorkers, 2)) - assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) < 0.1) + val model = new XGBoostClassifier(paramMap).fit(training) + assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) } ignore("test with fast histo lossguide") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1", - "objective" -> "binary:logistic", "tree_method" -> "hist", - "grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = math.min(numWorkers, 2)) - val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) + "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide", + "max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5, + "num_workers" -> math.min(numWorkers, 2)) + val model = new XGBoostClassifier(paramMap).fit(training) + val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) assert(x < 0.1) } ignore("test with fast histo lossguide with max 
bin") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0", - "objective" -> "binary:logistic", "tree_method" -> "hist", - "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16", - "eval_metric" -> "error") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = math.min(numWorkers, 2)) - val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) + "objective" -> "binary:logistic", "tree_method" -> "hist", + "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16", + "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2)) + val model = new XGBoostClassifier(paramMap).fit(training) + val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) assert(x < 0.1) } ignore("test with fast histo depthwidth with max depth") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0", "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2", - "eval_metric" -> "error") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10, - nWorkers = math.min(numWorkers, 2)) - val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) + "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2)) + val model = new XGBoostClassifier(paramMap).fit(training) + val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) assert(x < 0.1) } ignore("test with fast histo depthwidth with max depth and max bin") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0", - "objective" -> "binary:logistic", "tree_method" -> "hist", - "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2", - "eval_metric" -> "error") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10, - nWorkers = math.min(numWorkers, 2)) - val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) + "objective" -> "binary:logistic", "tree_method" -> "hist", + "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2", + "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2)) + val model = new XGBoostClassifier(paramMap).fit(training) + val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) assert(x < 0.1) } - test("test with dense vectors containing missing value") { - def buildDenseRDD(): RDD[MLLabeledPoint] = { + test("dense vectors 
containing missing value") { + def buildDenseDataFrame(): DataFrame = { val numRows = 100 val numCols = 5 - val labeledPoints = (0 until numRows).map { _ => - val label = Random.nextDouble() + val data = (0 until numRows).map { x => + val label = Random.nextInt(2) val values = Array.tabulate[Double](numCols) { c => - if (c == numCols - 1) -0.1 else Random.nextDouble() + if (c == numCols - 1) -0.1 else Random.nextDouble } - MLLabeledPoint(label, Vectors.dense(values)) + (label, Vectors.dense(values)) } - sc.parallelize(labeledPoints) + ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features") } - val trainingRDD = buildDenseRDD().repartition(4) - val testRDD = buildDenseRDD().repartition(4).map(_.features.asInstanceOf[DenseVector]) + val denseDF = buildDenseDataFrame().repartition(4) val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic").toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers, - useExternalMemory = true) - xgBoostModel.predict(testRDD, missingValue = -0.1f).collect() - } - - test("test consistency of prediction functions with RDD") { - import DataUtils._ - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSet = Classification.test - val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features) - val testCollection = testRDD.collect() - for (i <- testSet.indices) { - assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values)) - } - val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) - val predRDD = xgBoostModel.predict(testRDD) - val predResult1 = predRDD.collect() - assert(testRDD.count() === predResult1.length) - val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator)) - for (i <- predResult1.indices; j <- predResult1(i).indices) { - assert(predResult1(i)(j) === predResult2(i)(j)) - } - } - - test("test eval functions with RDD") { - import DataUtils._ - val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache() - val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers) - // Nan Zhu: deprecate it for now - // xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false) - xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false) - } - - test("test prediction functionality with empty partition") { - import DataUtils._ - def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = { - sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers) - } - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testRDD = buildEmptyRDD() - val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic").toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) - println(xgBoostModel.predict(testRDD).collect().length === 0) - } - - test("test use groupData") { - import DataUtils._ - val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML) - val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0) - val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features) - - val paramMap = Map("eta" -> "1", "max_depth" -> "2", 
"silent" -> "1", - "objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData) - - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 2, nWorkers = 1) - val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect() - assert(testRDD.count() === predResult1.length) - - val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData) - assert(avgMetric contains "ndcg") - // If the labels were lost ndcg comes back as 1.0 - assert(avgMetric.split('=')(1).toFloat < 1F) - } - - test("test use nested groupData") { - import DataUtils._ - val trainingRDD0 = sc.parallelize(Ranking.train0, numSlices = 1) - val trainingRDD1 = sc.parallelize(Ranking.train1, numSlices = 1) - val trainingRDD = trainingRDD0.union(trainingRDD1).map(_.asML) - - val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1) - - val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features) - - val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "rank:pairwise", "groupData" -> trainGroupData) - - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2) - val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect() - assert(testRDD.count() === predResult1.length) + "objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap + val model = new XGBoostClassifier(paramMap).fit(denseDF) + model.transform(denseDF).collect() } test("training with spark parallelism checks disabled") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers) - assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) < 0.1) - } - - test("isClassificationTask correctly classifies supported objectives") { - import org.scalatest.prop.TableDrivenPropertyChecks._ - - val objectives = Table( - ("isClassificationTask", "params"), - (true, Map("obj_type" -> "classification")), - (false, Map("obj_type" -> "regression")), - (false, Map("objective" -> "rank:ndcg")), - (false, Map("objective" -> "rank:pairwise")), - (false, Map("objective" -> "rank:map")), - (false, Map("objective" -> "count:poisson")), - (true, Map("objective" -> "binary:logistic")), - (true, Map("objective" -> "binary:logitraw")), - (true, Map("objective" -> "multi:softmax")), - (true, Map("objective" -> "multi:softprob")), - (false, Map("objective" -> "reg:linear")), - (false, Map("objective" -> "reg:logistic")), - (false, Map("objective" -> "reg:gamma")), - (false, Map("objective" -> "reg:tweedie"))) - forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) => - assert(XGBoost.isClassificationTask(params) == isClassificationTask) - } + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "timeout_request_workers" -> 0L, + "num_round" -> 5, "num_workers" -> numWorkers) + val model = new XGBoostClassifier(paramMap).fit(training) + val x 
= eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) + assert(x < 0.1) } test("training with checkpoint boosters") { - import DataUtils._ val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) + val training = buildDataFrame(Classification.train) + val testDM = new DMatrix(Classification.test.iterator) val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString - val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1", + val paramMap = Map("eta" -> "1", "max_depth" -> 2, "silent" -> "1", "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath, - "checkpoint_interval" -> 2).toMap - val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers) - def error(model: XGBoostModel): Float = eval.eval( - model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix) + "checkpoint_interval" -> 2, "num_workers" -> numWorkers) + + val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training) + def error(model: Booster): Float = eval.eval( + model.predict(testDM, outPutMargin = true), testDM) // Check only one model is kept after training val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "8.model") - val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model") + val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model") // Train next model based on prev model - val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8, - nWorkers = numWorkers) - assert(error(tmpModel) > error(prevModel)) - assert(error(prevModel) > error(nextModel)) - assert(error(nextModel) < 0.1) + val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) + assert(error(tmpModel) > error(prevModel._booster)) + assert(error(prevModel._booster) > error(nextModel._booster)) + assert(error(nextModel._booster) < 0.1) } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModelSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModelSuite.scala deleted file mode 100644 index a8d0d3773..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModelSuite.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package ml.dmlc.xgboost4j.scala.spark - -import java.nio.file.Files - -import ml.dmlc.xgboost4j.scala.DMatrix -import org.apache.spark.ml.linalg.Vector -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.rdd.RDD -import org.scalatest.FunSuite - -class XGBoostModelSuite extends FunSuite with PerTest { - test("test model consistency after save and load") { - import DataUtils._ - val eval = new EvalError() - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testSetDMatrix = new DMatrix(Classification.test.iterator) - val tempDir = Files.createTempDirectory("xgboosttest-") - val tempFile = Files.createTempFile(tempDir, "", "") - val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic") - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) - val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) - assert(evalResults < 0.1) - xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) - val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) - val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true) - val loadedEvalResults = eval.eval(predicts, testSetDMatrix) - assert(loadedEvalResults == evalResults) - } - - test("test save and load of different types of models") { - import DataUtils._ - val tempDir = Files.createTempDirectory("xgboosttest-") - val tempFile = Files.createTempFile(tempDir, "", "") - var trainingRDD = sc.parallelize(Classification.train).map(_.asML) - var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "reg:linear") - // validate regression model - var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers, useExternalMemory = false) - xgBoostModel.setFeaturesCol("feature_col") - xgBoostModel.setLabelCol("label_col") - xgBoostModel.setPredictionCol("prediction_col") - xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) - var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) - assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel]) - assert(loadedXGBoostModel.getFeaturesCol == "feature_col") - assert(loadedXGBoostModel.getLabelCol == "label_col") - assert(loadedXGBoostModel.getPredictionCol == "prediction_col") - // classification model - paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "binary:logistic") - xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers, useExternalMemory = false) - xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col") - xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5)) - xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) - loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) - assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel]) - assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol == - "raw_col") - assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep == - Array(0.5, 0.5).deep) - assert(loadedXGBoostModel.getFeaturesCol == "features") - assert(loadedXGBoostModel.getLabelCol == "label") - assert(loadedXGBoostModel.getPredictionCol == "prediction") - // (multiclass) classification model - trainingRDD = 
sc.parallelize(MultiClassification.train).map(_.asML) - paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "multi:softmax", "num_class" -> "6") - xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, - nWorkers = numWorkers, useExternalMemory = false) - xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col") - xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds( - Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5)) - xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath) - loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath) - assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel]) - assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol == - "raw_col") - assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep == - Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep) - assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6) - assert(loadedXGBoostModel.getFeaturesCol == "features") - assert(loadedXGBoostModel.getLabelCol == "label") - assert(loadedXGBoostModel.getPredictionCol == "prediction") - } - - test("copy and predict ClassificationModel") { - import DataUtils._ - val trainingRDD = sc.parallelize(Classification.train).map(_.asML) - val testRDD = sc.parallelize(Classification.test).map(_.features) - val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic") - val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) - testCopy(model, testRDD) - } - - test("copy and predict RegressionModel") { - import DataUtils._ - val trainingRDD = sc.parallelize(Regression.train).map(_.asML) - val testRDD = sc.parallelize(Regression.test).map(_.features) - val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "reg:linear") - val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) - testCopy(model, testRDD) - } - - private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = { - val modelCopy = model.copy(ParamMap.empty) - modelCopy.summary // Ensure no exception. - - val expected = model.predict(testRDD).collect - assert(modelCopy.predict(testRDD).collect === expected) - } -} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala new file mode 100644 index 000000000..86aa96d57 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -0,0 +1,114 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.scalatest.FunSuite + +class XGBoostRegressorSuite extends FunSuite with PerTest { + + test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j: regression") { + val trainingDM = new DMatrix(Regression.train.iterator) + val testDM = new DMatrix(Regression.test.iterator) + val trainingDF = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val round = 5 + + val paramMap = Map( + "eta" -> "1", + "max_depth" -> "6", + "silent" -> "1", + "objective" -> "reg:linear") + + val model1 = ScalaXGBoost.train(trainingDM, paramMap, round) + val prediction1 = model1.predict(testDM) + + val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round, + "num_workers" -> numWorkers)).fit(trainingDF) + + val prediction2 = model2.transform(testDF). + collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap + + assert(prediction1.indices.count { i => + math.abs(prediction1(i)(0) - prediction2(i)) > 0.01 + } < prediction1.length * 0.1) + } + + test("Set params in XGBoost and MLlib way should produce same model") { + val trainingDF = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val round = 5 + + val paramMap = Map( + "eta" -> "1", + "max_depth" -> "6", + "silent" -> "1", + "objective" -> "reg:linear", + "num_round" -> round, + "num_workers" -> numWorkers) + + // Set params in XGBoost way + val model1 = new XGBoostRegressor(paramMap).fit(trainingDF) + // Set params in MLlib way + val model2 = new XGBoostRegressor() + .setEta(1) + .setMaxDepth(6) + .setSilent(1) + .setObjective("reg:linear") + .setNumRound(round) + .setNumWorkers(numWorkers) + .fit(trainingDF) + + val prediction1 = model1.transform(testDF).select("prediction").collect() + val prediction2 = model2.transform(testDF).select("prediction").collect() + + prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) => + assert(math.abs(p1 - p2) <= 0.01f) + } + } + + test("ranking: use group data") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5, + "group_col" -> "group") + + val trainingDF = buildDataFrameWithGroup(Ranking.train) + val testDF = buildDataFrame(Ranking.test) + val model = new XGBoostRegressor(paramMap).fit(trainingDF) + + val prediction = model.transform(testDF).collect() + assert(testDF.count() === prediction.length) + } + + test("use weight") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers) + + val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType) + val trainingDF = buildDataFrame(Regression.train) + .withColumn("weight", getWeightFromId(col("id"))) + val testDF = buildDataFrame(Regression.test) + + val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF) + val prediction = model.transform(testDF).collect() + val first = prediction.head.getAs[Double]("prediction") + prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f)) + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala
b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala deleted file mode 100644 index f984a9da3..000000000 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSparkPipelinePersistence.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark - -import java.io.{File, FileNotFoundException} - -import scala.util.Random - -import org.apache.spark.SparkConf -import org.apache.spark.ml.feature._ -import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.sql.SparkSession -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -class XGBoostSparkPipelinePersistence extends FunSuite with PerTest - with BeforeAndAfterAll { - - override def afterAll(): Unit = { - delete(new File("./testxgbPipe")) - delete(new File("./testxgbEst")) - delete(new File("./testxgbModel")) - delete(new File("./test2xgbModel")) - } - - private def delete(f: File) { - if (f.exists()) { - if (f.isDirectory()) { - for (c <- f.listFiles()) { - delete(c) - } - } - if (!f.delete()) { - throw new FileNotFoundException("Failed to delete file: " + f) - } - } - } - - test("test persistence of XGBoostEstimator") { - val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "multi:softmax", "num_class" -> "6") - val xgbEstimator = new XGBoostEstimator(paramMap) - xgbEstimator.write.overwrite().save("./testxgbEst") - val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst") - val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap - paramMap.foreach { - case (k, v) => assert(v == loadedParamMap(k).toString) - } - } - - test("test persistence of a complete pipeline") { - val conf = new SparkConf().setAppName("foo").setMaster("local[*]") - val spark = SparkSession.builder().config(conf).getOrCreate() - val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "multi:softmax", "num_class" -> "6") - val r = new Random(0) - val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features") - val xgbEstimator = new XGBoostEstimator(paramMap) - val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator)) - pipeline.write.overwrite().save("testxgbPipe") - val loadedPipeline = Pipeline.read.load("testxgbPipe") - val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator] - val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap - paramMap.foreach { - case (k, v) => assert(v == loadedParamMap(k).toString) - } - } - - test("test persistence of XGBoostModel") { - val conf = new SparkConf().setAppName("foo").setMaster("local[*]") - val spark = SparkSession.builder().config(conf).getOrCreate() - val r = new Random(0) - // maybe move to shared context, but requires session to import implicits - val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))). 
- toDF("feature", "label") - val vectorAssembler = new VectorAssembler() - .setInputCols(df.columns - .filter(!_.contains("label"))) - .setOutputCol("features") - val xgbEstimator = new XGBoostEstimator(Map("num_round" -> 10, - "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala") - )).setFeaturesCol("features").setLabelCol("label") - // separate - val predModel = xgbEstimator.fit(vectorAssembler.transform(df)) - predModel.write.overwrite.save("test2xgbModel") - val same2Model = XGBoostModel.load("test2xgbModel") - - assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray)) - val predParamMap = predModel.extractParamMap() - val same2ParamMap = same2Model.extractParamMap() - assert(predParamMap.get(predModel.useExternalMemory) - === same2ParamMap.get(same2Model.useExternalMemory)) - assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol)) - assert(predParamMap.get(predModel.predictionCol) - === same2ParamMap.get(same2Model.predictionCol)) - assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol)) - assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol)) - - // chained - val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df) - predictionModel.write.overwrite.save("testxgbModel") - val sameModel = PipelineModel.load("testxgbModel") - - val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head - val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head - - assert(java.util.Arrays.equals( - predictionModelXGB.booster.toByteArray, - sameModelXGB.booster.toByteArray - )) - val predictionModelXGBParamMap = predictionModel.extractParamMap() - val sameModelXGBParamMap = sameModel.extractParamMap() - assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory) - === sameModelXGBParamMap.get(sameModelXGB.useExternalMemory)) - assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol) - === sameModelXGBParamMap.get(sameModelXGB.featuresCol)) - assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol) - === sameModelXGBParamMap.get(sameModelXGB.predictionCol)) - assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol) - === sameModelXGBParamMap.get(sameModelXGB.labelCol)) - assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol) - === sameModelXGBParamMap.get(sameModelXGB.labelCol)) - } -} - diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 5b74fb247..f885a6881 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -197,6 +197,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) booster.getFeatureScore(featureMap).asScala } + def getVersion: Int = booster.getVersion + def toByteArray: Array[Byte] = { booster.toByteArray }
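Usage note: the refactored test suites above all follow the same pattern, so a minimal sketch of the new flow may help reviewers. This sketch assumes a running SparkSession named `spark` and a hypothetical LIBSVM-format training file path; it uses only APIs that appear in this patch (the param-map constructor, DataFrame-based fit/transform, the `_booster` handle, and the new `Booster.getVersion` accessor), and is illustrative rather than part of the change itself.

    import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

    // num_round and num_workers now live in the parameter map, replacing the old
    // `round`/`nWorkers` arguments of XGBoost.trainWithRDD / trainWithDataFrame.
    val paramMap = Map(
      "eta" -> "1",
      "max_depth" -> "6",
      "objective" -> "binary:logistic",
      "num_round" -> 5,
      "num_workers" -> 2)

    val training = spark.read.format("libsvm").load("/path/to/train.libsvm") // hypothetical path
    val model = new XGBoostClassifier(paramMap).fit(training) // a standard Spark ML Estimator
    val predictions = model.transform(training)               // appends prediction columns
    val rawBooster = model._booster                           // underlying xgboost4j Booster
    val version = rawBooster.getVersion                       // accessor added in Booster.scala

Because XGBoostClassifier and XGBoostRegressor are plain Spark ML Estimators, they can be dropped into a Pipeline or a TrainValidationSplit like any MLlib stage, which is what the persistence and tuning changes in this patch rely on.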