diff --git a/.travis.yml b/.travis.yml index af06f2f27..c777d40fd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -49,7 +49,7 @@ addons: before_install: - source dmlc-core/scripts/travis/travis_setup_env.sh - export PYTHONPATH=${PYTHONPATH}:${PWD}/python-package - - echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m'" > ~/.mavenrc + - echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m -Dorg.slf4j.simpleLogger.defaultLogLevel=error'" > ~/.mavenrc install: - source tests/travis/setup.sh diff --git a/jvm-packages/create_jni.sh b/jvm-packages/create_jni.sh index 81d0d0992..0c6469b9a 100755 --- a/jvm-packages/create_jni.sh +++ b/jvm-packages/create_jni.sh @@ -31,6 +31,11 @@ mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl} cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py # copy test data files mkdir -p xgboost4j-spark/src/test/resources/ +cd ../demo/regression +python mapfeat.py +python mknfold.py machine.txt 1 +cd - +cp ../demo/regression/machine.txt.t* xgboost4j-spark/src/test/resources/ cp ../demo/data/agaricus.* xgboost4j-spark/src/test/resources/ popd > /dev/null echo "complete" diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala index b731a0b2d..851cffea9 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkWithRDD.scala @@ -20,6 +20,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector} +import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} object SparkWithRDD { def main(args: Array[String]): Unit = { @@ -38,8 +40,10 @@ object SparkWithRDD { // number of iterations val numRound = args(0).toInt import DataUtils._ - val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath) - val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().iterator + val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).map(lp => + MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray))) + val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().map( + lp => new MLDenseVector(lp.features.toArray)).iterator // training parameters val paramMap = List( "eta" -> 0.1f, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala index 4fae9ccd1..064ab0ea7 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala @@ -19,16 +19,17 @@ package ml.dmlc.xgboost4j.scala.spark import scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.LabeledPoint -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} -import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint} +import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} object DataUtils 
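As a usage note alongside the SparkWithRDD change above: the reworked DataUtils implicits consume `org.apache.spark.ml.feature.LabeledPoint` rather than the old mllib type, so data read through `MLUtils.loadLibSVMFile` has to be re-wrapped first. A minimal sketch of that conversion; the helper name `loadAsMLPoints` and the input path are illustrative only:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}

// Re-wrap mllib-typed points as ml.feature.LabeledPoint so they can feed the
// new DataUtils implicits and the RDD-based training entry points in this patch.
def loadAsMLPoints(sc: SparkContext, path: String): RDD[MLLabeledPoint] =
  MLUtils.loadLibSVMFile(sc, path).map { lp =>
    MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray))
  }
```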
extends Serializable { - implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint]) + + implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[MLLabeledPoint]) : java.util.Iterator[LabeledPoint] = { fromSparkPointsToXGBoostPoints(sps).asJava } - implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]): + implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[MLLabeledPoint]): Iterator[LabeledPoint] = { for (p <- sps) yield { p.features match { @@ -45,6 +46,7 @@ object DataUtils extends Serializable { : java.util.Iterator[LabeledPoint] = { fromSparkVectorToXGBoostPoints(sps).asJava } + implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector]) : Iterator[LabeledPoint] = { for (p <- sps) yield { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 9fe5fd264..a04f10fe9 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -23,26 +23,30 @@ import scala.collection.mutable.ListBuffer import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError} import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import org.apache.commons.logging.LogFactory -import org.apache.hadoop.fs.Path -import org.apache.spark.mllib.linalg.SparseVector -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.hadoop.fs.{FSDataInputStream, Path} +import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset import org.apache.spark.{SparkContext, TaskContext} object XGBoost extends Serializable { private val logger = LogFactory.getLog("XGBoostSpark") - private implicit def convertBoosterToXGBoostModel(booster: Booster) - (implicit sc: SparkContext): XGBoostModel = { - new XGBoostModel(booster) + private def convertBoosterToXGBoostModel(booster: Booster, isClassification: Boolean): + XGBoostModel = { + if (!isClassification) { + new XGBoostRegressionModel(booster) + } else { + new XGBoostClassificationModel(booster) + } } private def fromDenseToSparseLabeledPoints( - denseLabeledPoints: Iterator[LabeledPoint], - missing: Float): Iterator[LabeledPoint] = { + denseLabeledPoints: Iterator[MLLabeledPoint], + missing: Float): Iterator[MLLabeledPoint] = { if (!missing.isNaN) { - val sparseLabeledPoints = new ListBuffer[LabeledPoint] + val sparseLabeledPoints = new ListBuffer[MLLabeledPoint] for (labelPoint <- denseLabeledPoints) { val dVector = labelPoint.features.toDense val indices = new ListBuffer[Int] @@ -55,7 +59,7 @@ object XGBoost extends Serializable { } val sparseVector = new SparseVector(dVector.values.length, indices.toArray, values.toArray) - sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector) + sparseLabeledPoints += MLLabeledPoint(labelPoint.label, sparseVector) } sparseLabeledPoints.iterator } else { @@ -64,7 +68,7 @@ object XGBoost extends Serializable { } private[spark] def buildDistributedBoosters( - trainingData: RDD[LabeledPoint], + trainingData: RDD[MLLabeledPoint], xgBoostConfMap: Map[String, Any], rabitEnv: mutable.Map[String, String], numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait, @@ -124,20 +128,35 @@ 
object XGBoost extends Serializable { * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as * true, the user may save the RAM cost for running XGBoost within Spark * @param missing the value represented the missing value in the dataset - * @param inputCol the name of input column, "features" as default value + * @param featureCol the name of input column, "features" as default value * @param labelCol the name of output column, "label" as default value * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed * @return XGBoostModel when successful training */ @throws(classOf[XGBoostError]) - def trainWithDataFrame(trainingData: Dataset[_], - params: Map[String, Any], round: Int, - nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null, - useExternalMemory: Boolean = false, missing: Float = Float.NaN, - inputCol: String = "features", labelCol: String = "label"): XGBoostModel = { + def trainWithDataFrame( + trainingData: Dataset[_], + params: Map[String, Any], + round: Int, + nWorkers: Int, + obj: ObjectiveTrait = null, + eval: EvalTrait = null, + useExternalMemory: Boolean = false, + missing: Float = Float.NaN, + featureCol: String = "features", + labelCol: String = "label"): XGBoostModel = { require(nWorkers > 0, "you must specify more than 0 workers") - new XGBoostEstimator(inputCol, labelCol, params, round, nWorkers, obj, eval, - useExternalMemory, missing).fit(trainingData) + val estimator = new XGBoostEstimator(params, round, nWorkers, obj, eval, + useExternalMemory, missing) + estimator.setFeaturesCol(featureCol).setLabelCol(labelCol).fit(trainingData) + } + + private[spark] def isClassificationTask(objective: Option[Any]): Boolean = { + objective.isDefined && { + val objStr = objective.get.toString + objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" && + objStr != "rank:pairwise") + } } /** @@ -157,9 +176,9 @@ object XGBoost extends Serializable { */ @deprecated(since = "0.7", message = "this method is deprecated since 0.7, users are encouraged" + " to switch to trainWithRDD") - def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int, - nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null, - useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = { + def train(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int, + nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null, + useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = { require(nWorkers > 0, "you must specify more than 0 workers") trainWithRDD(trainingData, configMap, round, nWorkers, obj, eval, useExternalMemory, missing) } @@ -180,10 +199,15 @@ object XGBoost extends Serializable { * @return XGBoostModel when successful training */ @throws(classOf[XGBoostError]) - def trainWithRDD(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int, - nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null, - useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = { + def trainWithRDD(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int, + nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null, + useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = { require(nWorkers > 0, "you must specify more than 0 workers") + if (obj != null) { + require(configMap.get("obj_type").isDefined, "parameter 
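For orientation, a sketch of a call site for the reshaped `trainWithDataFrame` entry point. The DataFrame `df` (with a vector `features` column and a numeric `label` column), the helper name `fitOnDataFrame`, and the parameter values are all assumptions for illustration:

```scala
import org.apache.spark.sql.DataFrame
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}

// With "binary:logistic", the new isClassificationTask check returns true, so the
// returned XGBoostModel is in fact an XGBoostClassificationModel.
def fitOnDataFrame(df: DataFrame): XGBoostModel = {
  val params = Map(
    "eta" -> "0.1",
    "max_depth" -> "6",
    "objective" -> "binary:logistic")
  XGBoost.trainWithDataFrame(df, params, round = 5, nWorkers = 2,
    featureCol = "features", labelCol = "label")
}
```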
\"obj_type\" is not defined," + + " you have to specify the objective type as classification or regression with a" + + " customized objective function") + } val tracker = new RabitTracker(nWorkers) implicit val sc = trainingData.sparkContext var overridedConfMap = configMap @@ -209,7 +233,13 @@ object XGBoost extends Serializable { val returnVal = tracker.waitFor() logger.info(s"Rabit returns with exit code $returnVal") if (returnVal == 0) { - boosters.first() + convertBoosterToXGBoostModel(boosters.first(), + isClassificationTask( + if (obj == null) { + configMap.get("objective") + } else { + configMap.get("obj_type") + })) } else { try { if (sparkJobThread.isAlive) { @@ -223,6 +253,21 @@ object XGBoost extends Serializable { } } + private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = { + val featureCol = inputStream.readUTF() + val labelCol = inputStream.readUTF() + val predictionCol = inputStream.readUTF() + (featureCol, labelCol, predictionCol) + } + + private def setGeneralModelParams( + featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel): + XGBoostModel = { + xgBoostModel.setFeaturesCol(featureCol) + xgBoostModel.setLabelCol(labelCol) + xgBoostModel.setPredictionCol(predCol) + } + /** * Load XGBoost model from path in HDFS-compatible file system * @@ -233,7 +278,29 @@ object XGBoost extends Serializable { XGBoostModel = { val path = new Path(modelPath) val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path) - val xgBoostModel = new XGBoostModel(SXGBoost.loadModel(dataInStream)) - xgBoostModel + val modelType = dataInStream.readUTF() + val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream) + modelType match { + case "_cls_" => + val rawPredictionCol = dataInStream.readUTF() + val thresholdLength = dataInStream.readInt() + var thresholds: Array[Double] = null + if (thresholdLength != -1) { + thresholds = new Array[Double](thresholdLength) + for (i <- 0 until thresholdLength) { + thresholds(i) = dataInStream.readDouble() + } + } + val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream)) + setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel). + asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol) + if (thresholdLength != -1) { + xgBoostModel.setThresholds(thresholds) + } + xgBoostModel + case "_reg_" => + val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream)) + setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel) + } } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala new file mode 100644 index 000000000..d0ac2a75e --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassificationModel.scala @@ -0,0 +1,153 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} +import org.apache.spark.ml.linalg.{Vector => MLVector, DenseVector => MLDenseVector} +import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} + +class XGBoostClassificationModel private[spark]( + override val uid: String, _booster: Booster) + extends XGBoostModel(_booster) { + + def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), _booster) + + // scalastyle:off + + final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ") + + setDefault(outputMargin, false) + + def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel] + + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).") + + setDefault(rawPredictionCol, "probabilities") + + final def getRawPredictionCol: String = $(rawPredictionCol) + + def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel] + + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0)) + + def getThresholds: Array[Double] = $(thresholds) + + def setThresholds(value: Array[Double]): XGBoostClassificationModel = + set(thresholds, value).asInstanceOf[XGBoostClassificationModel] + + // scalastyle:on + + private def predictRaw( + testSet: Dataset[_], + temporalColName: Option[String] = None, + forceTransformedScore: Option[Boolean] = None): DataFrame = { + val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin))) + testSet.sparkSession.createDataFrame(predictRDD, schema = { + StructType(testSet.schema.add(StructField( + temporalColName.getOrElse($(rawPredictionCol)), + ArrayType(FloatType, containsNull = false), nullable = false))) + }) + } + + private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = { + val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol")) + val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol")) + val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF) + val allColumnNames = testSet.columns ++ Seq($(predictionCol)) + tempDF.select(allColumnNames(0), allColumnNames.tail: _*) + } + + private def argMax(vector: mutable.WrappedArray[Float]): Double = { + vector.zipWithIndex.maxBy(_._1)._2 + } + + private def raw2prediction(rawPrediction: mutable.WrappedArray[Float]): Double = { + if (!isDefined(thresholds)) { + argMax(rawPrediction) + } else { + probability2prediction(rawPrediction) + } + } + + private def probability2prediction(probability: mutable.WrappedArray[Float]): Double = { + if (!isDefined(thresholds)) { + argMax(probability) + } else { + val thresholds: Array[Double] = getThresholds + val scaledProbability: mutable.WrappedArray[Double] = + probability.zip(thresholds).map { case (p, t) => + if (t == 0.0) Double.PositiveInfinity else p / t + } + argMax(scaledProbability.map(_.toFloat)) + } + } + + override protected def transformImpl(testSet: Dataset[_]): DataFrame = { + transformSchema(testSet.schema, logging = true) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".transform() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + if ($(outputMargin)) { + setRawPredictionCol("margin") + } + var outputData = testSet + var numColsOutput = 0 + if ($(rawPredictionCol).nonEmpty) { + outputData = predictRaw(testSet) + numColsOutput += 1 + } + + if ($(predictionCol).nonEmpty) { + if ($(rawPredictionCol).nonEmpty) { + require(!$(outputMargin), "XGBoost does not support output final prediction with" + + " untransformed margin. 
Please set predictionCol as \"\" when setting outputMargin as" + + " true") + val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol))) + outputData = outputData.withColumn($(predictionCol), rawToPredUDF) + } else { + outputData = fromFeatureToPrediction(testSet) + } + numColsOutput += 1 + } + + if (numColsOutput == 0) { + this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData.toDF() + } + + private[spark] var numOfClasses = 2 + + def numClasses: Int = numOfClasses + + override def copy(extra: ParamMap): XGBoostClassificationModel = { + defaultCopy(extra) + } + + override protected def predict(features: MLVector): Double = { + throw new Exception("XGBoost does not support online prediction ") + } +} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 64ee91f8b..67d3fa4c5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -17,20 +17,18 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} -import org.apache.spark.ml.{Predictor, Estimator} +import org.apache.spark.ml.Predictor +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector => MLVector, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{VectorUDT, Vector} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{NumericType, DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row} +import org.apache.spark.sql.types.{StructType, DoubleType} +import org.apache.spark.sql.{Dataset, Row} /** * the estimator wrapping XGBoost to produce a training model * - * @param inputCol the name of input column - * @param labelCol the name of label column * @param xgboostParams the parameters configuring XGBoost * @param round the number of iterations to train * @param nWorkers the total number of workers of xgboost @@ -39,43 +37,47 @@ import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row} * @param useExternalMemory whether to use external memory when training * @param missing the value taken as missing */ -class XGBoostEstimator( - inputCol: String, labelCol: String, - xgboostParams: Map[String, Any], round: Int, nWorkers: Int, - obj: ObjectiveTrait = null, - eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN) - extends Estimator[XGBoostModel] { - - override val uid: String = Identifiable.randomUID("XGBoostEstimator") +class XGBoostEstimator private[spark]( + override val uid: String, xgboostParams: Map[String, Any], round: Int, nWorkers: Int, + obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean, missing: Float) + extends Predictor[MLVector, XGBoostEstimator, XGBoostModel] { + def this(xgboostParams: Map[String, Any], round: Int, nWorkers: Int, + obj: ObjectiveTrait = null, + eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN) = + this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any], round: Int, + nWorkers: Int, obj: ObjectiveTrait, 
eval: EvalTrait, useExternalMemory: Boolean, + missing: Float) /** * produce a XGBoostModel by fitting the given dataset */ - def fit(trainingSet: Dataset[_]): XGBoostModel = { + override def train(trainingSet: Dataset[_]): XGBoostModel = { val instances = trainingSet.select( - col(inputCol), col(labelCol).cast(DoubleType)).rdd.map { - case Row(feature: Vector, label: Double) => + col($(featuresCol)), col($(labelCol)).cast(DoubleType)).rdd.map { + case Row(feature: MLVector, label: Double) => LabeledPoint(label, feature) } transformSchema(trainingSet.schema, logging = true) val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, round, nWorkers, obj, eval, useExternalMemory, missing).setParent(this) - copyValues(trainedModel) + val returnedModel = copyValues(trainedModel) + if (XGBoost.isClassificationTask( + if (obj == null) xgboostParams.get("objective") else xgboostParams.get("obj_type"))) { + val numClass = { + if (xgboostParams.contains("num_class")) { + xgboostParams("num_class").asInstanceOf[Int] + } + else { + 2 + } + } + returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClass + } + returnedModel } - override def copy(extra: ParamMap): Estimator[XGBoostModel] = { + override def copy(extra: ParamMap): XGBoostEstimator = { defaultCopy(extra) } - - override def transformSchema(schema: StructType): StructType = { - // check input type, for now we only support vectorUDT as the input feature type - val inputType = schema(inputCol).dataType - require(inputType.equals(new VectorUDT), s"the type of input column $inputCol has to VectorUDT") - // check label Type, - val labelType = schema(labelCol).dataType - require(labelType.isInstanceOf[NumericType], s"the type of label column $labelCol has to" + - s" be NumericType") - schema - } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index eb81e0c22..da50309db 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -20,24 +20,48 @@ import scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait} -import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{Model, PredictionModel} -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{VectorUDT, DenseVector, Vector} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.apache.spark.ml.PredictionModel +import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint} +import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector} +import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.types.{FloatType, ArrayType, DataType} import org.apache.spark.{SparkContext, TaskContext} -class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable { +abstract class XGBoostModel(_booster: Booster) + extends PredictionModel[MLVector, 
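Because `XGBoostEstimator` now extends `Predictor`, it fits into the standard spark.ml Pipeline machinery. A sketch, with `trainingDF`, the column names, and the parameter values chosen purely for illustration:

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.sql.Dataset
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

// The estimator now picks its input columns through the usual Params API instead of
// constructor arguments.
def fitPipeline(trainingDF: Dataset[_]): PipelineModel = {
  val estimator = new XGBoostEstimator(
    Map("eta" -> "0.1", "objective" -> "binary:logistic"), round = 5, nWorkers = 2)
  estimator.setFeaturesCol("features").setLabelCol("label")
  new Pipeline().setStages(Array[PipelineStage](estimator)).fit(trainingDF)
}
```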
XGBoostModel] with Serializable with Params { - var inputCol = "features" - var outputCol = "prediction" - var outputType: DataType = ArrayType(elementType = FloatType, containsNull = false) + def setLabelCol(name: String): XGBoostModel = set(labelCol, name) + + // scalastyle:off + + final val useExternalMemory: Param[Boolean] = new Param[Boolean](this, "useExternalMemory", "whether to use external memory for prediction") + + setDefault(useExternalMemory, false) + + def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value) + + // scalastyle:on + + /** + * Predict leaf instances with the given test set (represented as RDD) + * + * @param testSet test set represented as RDD + */ + def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = { + import DataUtils._ + val broadcastBooster = testSet.sparkContext.broadcast(_booster) + testSet.mapPartitions { testSamples => + if (testSamples.hasNext) { + val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) + Iterator(broadcastBooster.value.predictLeaf(dMatrix)) + } else { + Iterator() + } + } + } /** * evaluate XGBoostModel with a RDD-wrapped dataset @@ -53,24 +77,25 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa * @param useExternalCache if use external cache * @return the average metric over all partitions */ - def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null, + def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null, iter: Int = -1, useExternalCache: Boolean = false): String = { - require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter") + require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter") val broadcastBooster = evalDataset.sparkContext.broadcast(_booster) val appName = evalDataset.context.appName val allEvalMetrics = evalDataset.mapPartitions { labeledPointsPartition => if (labeledPointsPartition.hasNext) { - val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap + val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) Rabit.init(rabitEnv.asJava) - import DataUtils._ val cacheFileName = { if (useExternalCache) { - s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}" + s"$appName-${TaskContext.get().stageId()}-$evalName" + + s"-deval_cache-${TaskContext.getPartitionId()}" } else { null } } + import DataUtils._ val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName) if (iter == -1) { val predictions = broadcastBooster.value.predict(dMatrix) @@ -91,18 +116,48 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa s"$evalPrefix = $evalMetricMean" } + /** + * Predict result with the given test set (represented as RDD) + * + * @param testSet test set represented as RDD + * @param missingValue the specified value to represent the missing value + */ + def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { + val broadcastBooster = testSet.sparkContext.broadcast(_booster) + testSet.mapPartitions { testSamples => + val sampleArray = testSamples.toList + val numRows = sampleArray.size + val numColumns = sampleArray.head.size + if (numRows == 0) { + Iterator() + } else { + val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) + Rabit.init(rabitEnv.asJava) + // translate to required format + val flatSampleArray = new Array[Float](numRows * numColumns) + for 
(i <- flatSampleArray.indices) { + flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat + } + val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) + Rabit.shutdown() + Iterator(broadcastBooster.value.predict(dMatrix)) + } + } + } + /** * Predict result with the given test set (represented as RDD) * * @param testSet test set represented as RDD * @param useExternalCache whether to use external cache for the test set */ - def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = { - import DataUtils._ + def predict(testSet: RDD[MLVector], useExternalCache: Boolean = false): + RDD[Array[Array[Float]]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) val appName = testSet.context.appName testSet.mapPartitions { testSamples => if (testSamples.hasNext) { + import DataUtils._ val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap Rabit.init(rabitEnv.asJava) val cacheFileName = { @@ -122,48 +177,76 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa } } + protected def transformImpl(testSet: Dataset[_]): DataFrame + /** - * Predict result with the given test set (represented as RDD) + * append leaf index of each row as an additional column in the original dataset * - * @param testSet test set represented as RDD - * @param missingValue the specified value to represent the missing value + * @return the original dataframe with an additional column containing prediction results */ - def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { - val broadcastBooster = testSet.sparkContext.broadcast(_booster) - testSet.mapPartitions { testSamples => - val sampleArray = testSamples.toList - val numRows = sampleArray.size - val numColumns = sampleArray.head.size - if (numRows == 0) { - Iterator() - } else { - // translate to required format - val flatSampleArray = new Array[Float](numRows * numColumns) - for (i <- flatSampleArray.indices) { - flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat + def transformLeaf(testSet: Dataset[_]): DataFrame = { + val predictRDD = produceRowRDD(testSet, predLeaf = true) + setPredictionCol("predLeaf") + transformSchema(testSet.schema, logging = true) + testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol), + ArrayType(FloatType, containsNull = false))) + } + + protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false, + predLeaf: Boolean = false): RDD[Row] = { + val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster) + val appName = testSet.sparkSession.sparkContext.appName + testSet.rdd.mapPartitions { + rowIterator => + if (rowIterator.hasNext) { + val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap + Rabit.init(rabitEnv.asJava) + val (rowItr1, rowItr2) = rowIterator.duplicate + val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector]( + $(featuresCol))).toList.iterator + import DataUtils._ + val cachePrefix = { + if ($(useExternalMemory)) { + s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}" + } else { + null + } + } + val testDataset = new DMatrix(vectorIterator, cachePrefix) + val rawPredictResults = { + if (!predLeaf) { + broadcastBooster.value.predict(testDataset, outputMargin). 
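The RDD-oriented predict variants keep working against the new ml vector types. A small usage sketch; `model` and `vectors` are assumed inputs, and `0.0f` as the missing-value marker is only an example:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}
import ml.dmlc.xgboost4j.scala.spark.XGBoostModel

// Each partition yields one Array[Array[Float]] of per-row prediction arrays.
def scoreDense(model: XGBoostModel, vectors: RDD[MLDenseVector]): RDD[Array[Array[Float]]] =
  model.predict(vectors, missingValue = 0.0f)
```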
+ map(Row(_)).iterator + } else { + broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator + } + } + Rabit.shutdown() + // concatenate original data partition and predictions + rowItr1.zip(rawPredictResults).map { + case (originalColumns: Row, predictColumn: Row) => + Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq) + } + } else { + Iterator[Row]() } - val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) - Iterator(broadcastBooster.value.predict(dMatrix)) - } } } /** - * Predict leaf instances with the given test set (represented as RDD) + * produces the prediction results and append as an additional column in the original dataset + * NOTE: the prediction results is kept as the original format of xgboost * - * @param testSet test set represented as RDD + * @return the original dataframe with an additional column containing prediction results */ - def predictLeaves(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = { - import DataUtils._ - val broadcastBooster = testSet.sparkContext.broadcast(_booster) - testSet.mapPartitions { testSamples => - if (testSamples.hasNext) { - val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) - Iterator(broadcastBooster.value.predictLeaf(dMatrix)) - } else { - Iterator() - } - } + override def transform(testSet: Dataset[_]): DataFrame = { + transformImpl(testSet) + } + + private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = { + outputStream.writeUTF(getFeaturesCol) + outputStream.writeUTF(getLabelCol) + outputStream.writeUTF(getPredictionCol) } /** @@ -174,109 +257,34 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = { val path = new Path(modelPath) val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path) + // output model type + this match { + case model: XGBoostClassificationModel => + outputStream.writeUTF("_cls_") + saveGeneralModelParam(outputStream) + outputStream.writeUTF(model.getRawPredictionCol) + // threshold + // threshold length + if (!isDefined(model.thresholds)) { + outputStream.writeInt(-1) + } else { + val thresholdLength = model.getThresholds.length + outputStream.writeInt(thresholdLength) + for (i <- 0 until thresholdLength) { + outputStream.writeDouble(model.getThresholds(i)) + } + } + case model: XGBoostRegressionModel => + outputStream.writeUTF("_reg_") + // eventual prediction col + saveGeneralModelParam(outputStream) + } + // booster _booster.saveModel(outputStream) outputStream.close() } + // override protected def featuresDataType: DataType = new VectorUDT + def booster: Booster = _booster - - override val uid: String = Identifiable.randomUID("XGBoostModel") - - override def copy(extra: ParamMap): XGBoostModel = { - defaultCopy(extra) - } - - /** - * append leaf index of each row as an additional column in the original dataset - * - * @return the original dataframe with an additional column containing prediction results - */ - def transformLeaf(testSet: Dataset[_]): Unit = { - outputCol = "predLeaf" - transformSchema(testSet.schema, logging = true) - val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster) - val instances = testSet.rdd.mapPartitions { - rowIterator => - if (rowIterator.hasNext) { - val (rowItr1, rowItr2) = rowIterator.duplicate - val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)). 
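The persistence changes can be exercised with a simple round trip. The path below is illustrative, and the loader is assumed to be `XGBoost.loadModelFromHadoopFile`, whose body appears in the XGBoost.scala hunk above:

```scala
import org.apache.spark.SparkContext
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}

// saveModelAsHadoopFile now writes a "_cls_"/"_reg_" marker plus the column parameters
// ahead of the booster, which the loader consumes to rebuild the right model subclass.
def roundTrip(model: XGBoostModel, sc: SparkContext): XGBoostModel = {
  implicit val context: SparkContext = sc
  model.saveModelAsHadoopFile("/tmp/xgb-model")
  XGBoost.loadModelFromHadoopFile("/tmp/xgb-model")
}
```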
- toList.iterator - import DataUtils._ - val testDataset = new DMatrix(vectorIterator, null) - val rowPredictResults = broadcastBooster.value.predictLeaf(testDataset) - val predictResults = rowPredictResults.map(prediction => Row(prediction)).iterator - rowItr1.zip(predictResults).map { - case (originalColumns: Row, predictColumn: Row) => - Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq) - } - } else { - Iterator[Row]() - } - } - testSet.sparkSession.createDataFrame(instances, testSet.schema.add(outputCol, outputType)). - cache() - } - - /** - * produces the prediction results and append as an additional column in the original dataset - * NOTE: the prediction results is kept as the original format of xgboost - * - * @return the original dataframe with an additional column containing prediction results - */ - override def transform(testSet: Dataset[_]): DataFrame = { - transform(testSet, None) - } - - /** - * produces the prediction results and append as an additional column in the original dataset - * NOTE: the prediction results is transformed by applying the transformation function - * predictResultTrans to the original xgboost output - * - * @param rawPredictTransformer the function to transform xgboost output to the expected format - * @return the original dataframe with an additional column containing prediction results - */ - def transform(testSet: Dataset[_], rawPredictTransformer: Option[Array[Float] => DataType]): - DataFrame = { - transformSchema(testSet.schema, logging = true) - val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster) - val instances = testSet.rdd.mapPartitions { - rowIterator => - if (rowIterator.hasNext) { - val (rowItr1, rowItr2) = rowIterator.duplicate - val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)). - toList.iterator - import DataUtils._ - val testDataset = new DMatrix(vectorIterator, null) - val rowPredictResults = broadcastBooster.value.predict(testDataset) - val predictResults = { - if (rawPredictTransformer.isDefined) { - rowPredictResults.map(prediction => - Row(rawPredictTransformer.get(prediction))).iterator - } else { - rowPredictResults.map(prediction => Row(prediction)).iterator - } - } - rowItr1.zip(predictResults).map { - case (originalColumns: Row, predictColumn: Row) => - Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq) - } - } else { - Iterator[Row]() - } - } - testSet.sparkSession.createDataFrame(instances, testSet.schema.add(outputCol, outputType)). 
- cache() - } - - @DeveloperApi - override def transformSchema(schema: StructType): StructType = { - if (schema.fieldNames.contains(outputCol)) { - throw new IllegalArgumentException(s"Output column $outputCol already exists.") - } - val inputType = schema(inputCol).dataType - require(inputType.equals(new VectorUDT), - s"the type of input column $inputCol has to be VectorUDT") - val outputFields = schema.fields :+ StructField(outputCol, outputType, nullable = false) - StructType(outputFields) - } } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala new file mode 100644 index 000000000..7e398ea3d --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressionModel.scala @@ -0,0 +1,48 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.scala.Booster +import org.apache.spark.ml.linalg.{Vector => MLVector} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} + +class XGBoostRegressionModel private[spark](override val uid: String, _booster: Booster) + extends XGBoostModel(_booster) { + + def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster) + + override protected def transformImpl(testSet: Dataset[_]): DataFrame = { + transformSchema(testSet.schema, logging = true) + val predictRDD = produceRowRDD(testSet) + testSet.sparkSession.createDataFrame(predictRDD, schema = + StructType(testSet.schema.add(StructField($(predictionCol), + ArrayType(FloatType, containsNull = false), nullable = false))) + ) + } + + override protected def predict(features: MLVector): Double = { + throw new Exception("XGBoost does not support online prediction for now") + } + + override def copy(extra: ParamMap): XGBoostRegressionModel = { + defaultCopy(extra) + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala index ec37ec0cd..91a840911 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala @@ -50,6 +50,8 @@ class EvalError extends EvalTrait { logger.error(ex) return -1f } + require(predicts.length == labels.length, s"predicts length ${predicts.length} has to be" + + s" equal with label length ${labels.length}") val nrow: Int = predicts.length for (i <- 0 until nrow) { if (labels(i) == 0.0 && predicts(i)(0) > 0) { diff --git 
a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala index 0729cde0d..65b886d2b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/SharedSparkContext.scala @@ -17,20 +17,21 @@ package ml.dmlc.xgboost4j.scala.spark import org.apache.spark.{SparkConf, SparkContext} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.{BeforeAndAfterAll, FunSuite} -class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable { +trait SharedSparkContext extends FunSuite with BeforeAndAfterAll with Serializable { @transient protected implicit var sc: SparkContext = null - before { + override def beforeAll() { // build SparkContext - val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite") + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite"). + set("spark.driver.memory", "512m") sc = new SparkContext(sparkConf) sc.setLogLevel("ERROR") } - after { + override def afterAll() { if (sc != null) { sc.stop() } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala index 83dbb3e1e..56c373e4e 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala @@ -21,17 +21,23 @@ import java.io.File import scala.collection.mutable.ListBuffer import scala.io.Source -import ml.dmlc.xgboost4j.java.XGBoostError -import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait} -import org.apache.commons.logging.LogFactory import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{DenseVector, Vector => SparkVector} import org.apache.spark.rdd.RDD trait Utils extends Serializable { protected val numWorkers = Runtime.getRuntime().availableProcessors() + protected var labeledPointsRDD: RDD[LabeledPoint] = null + + protected def cleanExternalCache(prefix: String): Unit = { + val dir = new File(".") + for (file <- dir.listFiles() if file.getName.startsWith(prefix)) { + file.delete() + } + } + protected def loadLabelPoints(filePath: String): List[LabeledPoint] = { val file = Source.fromFile(new File(filePath)) val sampleList = new ListBuffer[LabeledPoint] @@ -41,6 +47,15 @@ trait Utils extends Serializable { sampleList.toList } + protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = { + val file = Source.fromFile(new File(filePath)) + val sampleList = new ListBuffer[(Double, SparkVector)] + for (sample <- file.getLines()) { + sampleList += fromSVMStringToLabelAndVector(sample) + } + sampleList.toList + } + protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = { val labelAndFeatures = line.split(" ") val label = labelAndFeatures(0).toDouble @@ -59,7 +74,10 @@ trait Utils extends Serializable { } protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = { - val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile) - 
sparkContext.parallelize(sampleList, numWorkers) + if (labeledPointsRDD == null) { + val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile) + labeledPointsRDD = sparkContext.parallelize(sampleList, numWorkers) + } + labeledPointsRDD } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala new file mode 100644 index 000000000..c45d1f1f1 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala @@ -0,0 +1,60 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix} +import org.apache.spark.{SparkConf, SparkContext} +import org.scalatest.FunSuite + +class XGBoostConfigureSuite extends FunSuite with Utils { + + test("nthread configuration must be equal to spark.task.cpus") { + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite"). + set("spark.task.cpus", "4") + val customSparkContext = new SparkContext(sparkConf) + customSparkContext.setLogLevel("ERROR") + // start another app + val trainingRDD = buildTrainingRDD(customSparkContext) + val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic", "nthread" -> 6) + intercept[IllegalArgumentException] { + XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) + } + customSparkContext.stop() + } + + test("kryoSerializer test") { + labeledPointsRDD = null + val eval = new EvalError() + val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sparkConf.registerKryoClasses(Array(classOf[Booster])) + val customSparkContext = new SparkContext(sparkConf) + customSparkContext.setLogLevel("ERROR") + val trainingRDD = buildTrainingRDD(customSparkContext) + val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator + import DataUtils._ + val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null)) + val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic") + val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) + assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), + testSetDMatrix) < 0.1) + customSparkContext.stop() + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala index 48b450e60..284f99a22 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala +++ 
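The new XGBoostConfigureSuite also shows how to run with Kryo serialization. The SparkConf setup, extracted as a standalone sketch with an illustrative master and app name:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import ml.dmlc.xgboost4j.scala.Booster

// Booster must be registered explicitly when Kryo serialization is enabled,
// exactly as the "kryoSerializer test" above does.
def kryoContext(): SparkContext = {
  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("xgboost-kryo-example")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  conf.registerKryoClasses(Array(classOf[Booster]))
  new SparkContext(conf)
}
```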
b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostDFSuite.scala @@ -25,77 +25,27 @@ import scala.io.Source import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType} class XGBoostDFSuite extends SharedSparkContext with Utils { - private def loadRow(filePath: String): List[Row] = { - val file = Source.fromFile(new File(filePath)) - val rowList = new ListBuffer[Row] - for (rowLine <- file.getLines()) { - rowList += fromSVMStringToRow(rowLine) + private var trainingDF: DataFrame = null + + private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None): DataFrame = { + if (trainingDF == null) { + val rowList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile) + val labeledPointsRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers) + val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate() + import sparkSession.implicits._ + trainingDF = sparkSession.createDataset(labeledPointsRDD).toDF } - rowList.toList + trainingDF } - private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None): - DataFrame = { - val rowList = loadRow(getClass.getResource("/agaricus.txt.train").getFile) - val rowRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers) - val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate() - sparkSession.createDataFrame(rowRDD, - StructType(Array(StructField("label", DoubleType, nullable = false), - StructField("features", new VectorUDT, nullable = false)))) - } - - private def fromSVMStringToRow(line: String): Row = { - val (label, sv) = fromSVMStringToLabelAndVector(line) - Row(label, sv) - } - - test("test consistency between training with dataframe and RDD") { - val trainingDF = buildTrainingDataframe() - val trainingRDD = buildTrainingRDD(sc) - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0", - "objective" -> "binary:logistic").toMap - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 5, nWorkers = numWorkers, useExternalMemory = false) - val xgBoostModelWithRDD = XGBoost.trainWithRDD(trainingRDD, paramMap, - round = 5, nWorkers = numWorkers, useExternalMemory = false) - val eval = new EvalError() - val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator - import DataUtils._ - val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null)) - assert( - eval.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix) === - eval.eval(xgBoostModelWithRDD.booster.predict(testSetDMatrix, outPutMargin = true), - testSetDMatrix)) - } - - test("test transform of dataframe-based model") { - val trainingDF = buildTrainingDataframe() - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0", - "objective" -> "binary:logistic").toMap - val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, - round = 5, nWorkers = numWorkers, useExternalMemory = false) - val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile) - val testRowsRDD = sc.parallelize(testSet.zipWithIndex, numWorkers).map{ - case (instance: 
LabeledPoint, id: Int) => - Row(id, instance.features, instance.label) - } - val testDF = trainingDF.sparkSession.createDataFrame(testRowsRDD, StructType( - Array(StructField("id", IntegerType), - StructField("features", new VectorUDT), StructField("label", DoubleType)))) - xgBoostModelWithDF.transform(testDF).show() - } - - test("test order preservation of dataframe-based model") { - val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0", - "objective" -> "binary:logistic").toMap + test("test consistency and order preservation of dataframe-based model") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic") val trainingItr = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile). iterator val (testItr, auxTestItr) = @@ -105,25 +55,109 @@ class XGBoostDFSuite extends SharedSparkContext with Utils { val testDMatrix = new DMatrix(new JDMatrix(testItr, null)) val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5) val predResultFromSeq = xgboostModel.predict(testDMatrix) - val testRowsRDD = sc.parallelize( - auxTestItr.toList.zipWithIndex, numWorkers).map { + val testSetItr = auxTestItr.zipWithIndex.map { case (instance: LabeledPoint, id: Int) => - Row(id, instance.features, instance.label) + (id, instance.features, instance.label) } val trainingDF = buildTrainingDataframe() val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers, useExternalMemory = false) - val testDF = trainingDF.sqlContext.createDataFrame(testRowsRDD, StructType( - Array(StructField("id", IntegerType), StructField("features", new VectorUDT), - StructField("label", DoubleType)))) - val predResultsFromDF = - xgBoostModelWithDF.transform(testDF).collect().map(row => (row.getAs[Int]("id"), - row.getAs[mutable.WrappedArray[Float]]("prediction"))).toMap + val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF( + "id", "features", "label") + val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF). + collect().map(row => + (row.getAs[Int]("id"), row.getAs[mutable.WrappedArray[Float]]("probabilities")) + ).toMap + assert(testDF.count() === predResultsFromDF.size) for (i <- predResultFromSeq.indices) { assert(predResultFromSeq(i).length === predResultsFromDF(i).length) for (j <- predResultFromSeq(i).indices) { assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j)) } } + cleanExternalCache("XGBoostDFSuite") + } + + test("test transformLeaf") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic") + val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator + val trainingDF = buildTrainingDataframe() + val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, + round = 5, nWorkers = numWorkers, useExternalMemory = false) + val testSetItr = testItr.zipWithIndex.map { + case (instance: LabeledPoint, id: Int) => + (id, instance.features, instance.label) + } + val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF( + "id", "features", "label") + xgBoostModelWithDF.transformLeaf(testDF).show() + } + + test("test schema of XGBoostRegressionModel") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear") + val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator. 
+ zipWithIndex.map { case (instance: LabeledPoint, id: Int) => + (id, instance.features, instance.label) + } + val trainingDF = { + val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile) + val labeledPointsRDD = sc.parallelize(rowList, numWorkers) + val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate() + import sparkSession.implicits._ + sparkSession.createDataset(labeledPointsRDD).toDF + } + val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, + round = 5, nWorkers = numWorkers, useExternalMemory = true) + xgBoostModelWithDF.setPredictionCol("final_prediction") + val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF( + "id", "features", "label") + val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF) + assert(predictionDF.columns.contains("id") === true) + assert(predictionDF.columns.contains("features") === true) + assert(predictionDF.columns.contains("label") === true) + assert(predictionDF.columns.contains("final_prediction") === true) + predictionDF.show() + cleanExternalCache("XGBoostDFSuite") + } + + test("test schema of XGBoostClassificationModel") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic") + val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator. + zipWithIndex.map { case (instance: LabeledPoint, id: Int) => + (id, instance.features, instance.label) + } + val trainingDF = buildTrainingDataframe() + val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap, + round = 5, nWorkers = numWorkers, useExternalMemory = true) + xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol( + "raw_prediction").setPredictionCol("final_prediction") + val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF( + "id", "features", "label") + var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF) + assert(predictionDF.columns.contains("id") === true) + assert(predictionDF.columns.contains("features") === true) + assert(predictionDF.columns.contains("label") === true) + assert(predictionDF.columns.contains("raw_prediction") === true) + assert(predictionDF.columns.contains("final_prediction") === true) + xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(""). + setPredictionCol("final_prediction") + predictionDF = xgBoostModelWithDF.transform(testDF) + assert(predictionDF.columns.contains("id") === true) + assert(predictionDF.columns.contains("features") === true) + assert(predictionDF.columns.contains("label") === true) + assert(predictionDF.columns.contains("raw_prediction") === false) + assert(predictionDF.columns.contains("final_prediction") === true) + xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel]. 
+      setRawPredictionCol("raw_prediction").setPredictionCol("")
+    predictionDF = xgBoostModelWithDF.transform(testDF)
+    assert(predictionDF.columns.contains("id") === true)
+    assert(predictionDF.columns.contains("features") === true)
+    assert(predictionDF.columns.contains("label") === true)
+    assert(predictionDF.columns.contains("raw_prediction") === true)
+    assert(predictionDF.columns.contains("final_prediction") === false)
+    cleanExternalCache("XGBoostDFSuite")
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 956bf859c..dff4bc9d9 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -16,66 +16,47 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import java.io.File
 import java.nio.file.Files
 
 import scala.collection.mutable.ListBuffer
 import scala.util.Random
 
-import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix}
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => ScalaXGBoost}
-import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors}
-import org.apache.spark.mllib.regression.LabeledPoint
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
+import ml.dmlc.xgboost4j.scala.DMatrix
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkConf, SparkContext}
 
 class XGBoostGeneralSuite extends SharedSparkContext with Utils {
 
   test("build RDD containing boosters with the specified worker number") {
     val trainingRDD = buildTrainingRDD(sc)
-    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
-    import DataUtils._
-    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
     val boosterRDD = XGBoost.buildDistributedBoosters(
       trainingRDD,
-      List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+      List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
+      numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === 2)
-    val boosters = boosterRDD.collect()
-    val eval = new EvalError()
-    for (booster <- boosters) {
-      // the threshold is 0.11 because it does not sync boosters with AllReduce
-      val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
-      assert(eval.eval(predicts, testSetDMatrix) < 0.11)
-    }
+    cleanExternalCache("XGBoostSuite")
   }
 
   test("training with external memory cache") {
-    sc.stop()
-    sc = null
-    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
-    val customSparkContext = new SparkContext(sparkConf)
-    customSparkContext.setLogLevel("ERROR")
     val eval = new EvalError()
-    val trainingRDD = buildTrainingRDD(customSparkContext)
+    val trainingRDD = buildTrainingRDD(sc)
     val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
     import DataUtils._
     val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
-    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
+    val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic").toMap val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers, useExternalMemory = true) assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix) < 0.1) - customSparkContext.stop() // clean - val dir = new File(".") - for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-0-dtrain_cache")) { - file.delete() - } + cleanExternalCache("XGBoostSuite") } test("test with dense vectors containing missing value") { @@ -106,10 +87,13 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { } val trainingRDD = buildDenseRDD().repartition(4) val testRDD = buildDenseRDD().repartition(4) - val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", "objective" -> "binary:logistic").toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) + val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers, + useExternalMemory = true) xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect() + // clean + cleanExternalCache("XGBoostSuite") } test("test consistency of prediction functions with RDD") { @@ -120,11 +104,12 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { for (i <- testSet.indices) { assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values)) } - val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", - "objective" -> "binary:logistic").toMap + val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic") val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) val predRDD = xgBoostModel.predict(testRDD) val predResult1 = predRDD.collect()(0) + assert(testRDD.count() === predResult1.length) import DataUtils._ val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator)) for (i <- predResult1.indices; j <- predResult1(i).indices) { @@ -134,9 +119,9 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { test("test eval functions with RDD") { val trainingRDD = buildTrainingRDD(sc).cache() - val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", - "objective" -> "binary:logistic").toMap - val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) + val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic") + val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers) xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false) xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false) } @@ -150,7 +135,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val testRDD = buildEmptyRDD() val tempDir = Files.createTempDirectory("xgboosttest-") val tempFile = Files.createTempFile(tempDir, "", "") - val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0", + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", "objective" -> "binary:logistic").toMap val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) println(xgBoostModel.predict(testRDD).collect().length === 0) @@ -164,8 +149,8 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val testSetDMatrix = new 
     val tempDir = Files.createTempDirectory("xgboosttest-")
     val tempFile = Files.createTempFile(tempDir, "", "")
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-      "objective" -> "binary:logistic").toMap
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic")
     val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
     val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
       testSetDMatrix)
@@ -177,41 +162,40 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
     assert(loadedEvalResults == evalResults)
   }
 
-  test("nthread configuration must be equal to spark.task.cpus") {
-    sc.stop()
-    sc = null
-    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
-      set("spark.task.cpus", "4")
-    val customSparkContext = new SparkContext(sparkConf)
-    customSparkContext.setLogLevel("ERROR")
-    // start another app
-    val trainingRDD = buildTrainingRDD(customSparkContext)
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-      "objective" -> "binary:logistic", "nthread" -> 6).toMap
-    intercept[IllegalArgumentException] {
-      XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
-    }
-    customSparkContext.stop()
-  }
-
-  test("kryoSerializer test") {
-    sc.stop()
-    sc = null
-    val eval = new EvalError()
-    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-    sparkConf.registerKryoClasses(Array(classOf[Booster]))
-    val customSparkContext = new SparkContext(sparkConf)
-    customSparkContext.setLogLevel("ERROR")
-    val trainingRDD = buildTrainingRDD(customSparkContext)
-    val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
-    import DataUtils._
-    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
-    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
-      "objective" -> "binary:logistic").toMap
-    val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
-    assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
-      testSetDMatrix) < 0.1)
-    customSparkContext.stop()
+  test("test save and load of different types of models") {
+    val tempDir = Files.createTempDirectory("xgboosttest-")
+    val tempFile = Files.createTempFile(tempDir, "", "")
+    val trainingRDD = buildTrainingRDD(sc)
+    var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "reg:linear")
+    // validate regression model
+    var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = false)
+    xgBoostModel.setFeaturesCol("feature_col")
+    xgBoostModel.setLabelCol("label_col")
+    xgBoostModel.setPredictionCol("prediction_col")
+    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
+    var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
+    assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
+    assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
+    assert(loadedXGBoostModel.getLabelCol == "label_col")
+    assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
+    // classification model
+    paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic")
+    xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers, useExternalMemory = false)
+    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
+    xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
+    xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
+    loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
+    assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
+      "raw_col")
+    assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
+      Array(0.5, 0.5).deep)
+    assert(loadedXGBoostModel.getFeaturesCol == "features")
+    assert(loadedXGBoostModel.getLabelCol == "label")
+    assert(loadedXGBoostModel.getPredictionCol == "prediction")
   }
 }
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
index 388005edc..587ace352 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
@@ -38,6 +38,8 @@ trait EvalTrait extends IEvaluation {
   def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
 
   private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
+    require(predicts.length == jdmat.getLabel.length, "predicts size and label size must match, " +
+      s"predicts size: ${predicts.length}, label size: ${jdmat.getLabel.length}")
     eval(predicts, new DMatrix(jdmat))
   }
 }