[jvm-packages] separate classification and regression model and integrate with ML package (#1608)

Nan Zhu 2016-09-30 11:49:03 -04:00 committed by GitHub
parent 3b9987ca9c
commit 1673bcbe7e
16 changed files with 771 additions and 381 deletions

View File

@ -49,7 +49,7 @@ addons:
before_install:
- source dmlc-core/scripts/travis/travis_setup_env.sh
- export PYTHONPATH=${PYTHONPATH}:${PWD}/python-package
- echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m'" > ~/.mavenrc
- echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m -Dorg.slf4j.simpleLogger.defaultLogLevel=error'" > ~/.mavenrc
install:
- source tests/travis/setup.sh

View File

@ -31,6 +31,11 @@ mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py
# copy test data files
mkdir -p xgboost4j-spark/src/test/resources/
cd ../demo/regression
python mapfeat.py
python mknfold.py machine.txt 1
cd -
cp ../demo/regression/machine.txt.t* xgboost4j-spark/src/test/resources/
cp ../demo/data/agaricus.* xgboost4j-spark/src/test/resources/
popd > /dev/null
echo "complete"

View File

@ -20,6 +20,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
object SparkWithRDD {
def main(args: Array[String]): Unit = {
@ -38,8 +40,10 @@ object SparkWithRDD {
// number of iterations
val numRound = args(0).toInt
import DataUtils._
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().iterator
val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).map(lp =>
MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray)))
val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().map(
lp => new MLDenseVector(lp.features.toArray)).iterator
// training parameters
val paramMap = List(
"eta" -> 0.1f,

View File

@ -19,16 +19,17 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.LabeledPoint
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
object DataUtils extends Serializable {
implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[MLLabeledPoint])
: java.util.Iterator[LabeledPoint] = {
fromSparkPointsToXGBoostPoints(sps).asJava
}
implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]):
implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[MLLabeledPoint]):
Iterator[LabeledPoint] = {
for (p <- sps) yield {
p.features match {
@ -45,6 +46,7 @@ object DataUtils extends Serializable {
: java.util.Iterator[LabeledPoint] = {
fromSparkVectorToXGBoostPoints(sps).asJava
}
implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector])
: Iterator[LabeledPoint] = {
for (p <- sps) yield {

View File

@ -23,26 +23,30 @@ import scala.collection.mutable.ListBuffer
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.Dataset
import org.apache.spark.{SparkContext, TaskContext}
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private implicit def convertBoosterToXGBoostModel(booster: Booster)
(implicit sc: SparkContext): XGBoostModel = {
new XGBoostModel(booster)
private def convertBoosterToXGBoostModel(booster: Booster, isClassification: Boolean):
XGBoostModel = {
if (!isClassification) {
new XGBoostRegressionModel(booster)
} else {
new XGBoostClassificationModel(booster)
}
}
private def fromDenseToSparseLabeledPoints(
denseLabeledPoints: Iterator[LabeledPoint],
missing: Float): Iterator[LabeledPoint] = {
denseLabeledPoints: Iterator[MLLabeledPoint],
missing: Float): Iterator[MLLabeledPoint] = {
if (!missing.isNaN) {
val sparseLabeledPoints = new ListBuffer[LabeledPoint]
val sparseLabeledPoints = new ListBuffer[MLLabeledPoint]
for (labelPoint <- denseLabeledPoints) {
val dVector = labelPoint.features.toDense
val indices = new ListBuffer[Int]
@ -55,7 +59,7 @@ object XGBoost extends Serializable {
}
val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
values.toArray)
sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector)
sparseLabeledPoints += MLLabeledPoint(labelPoint.label, sparseVector)
}
sparseLabeledPoints.iterator
} else {
@ -64,7 +68,7 @@ object XGBoost extends Serializable {
}
private[spark] def buildDistributedBoosters(
trainingData: RDD[LabeledPoint],
trainingData: RDD[MLLabeledPoint],
xgBoostConfMap: Map[String, Any],
rabitEnv: mutable.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
@ -124,20 +128,35 @@ object XGBoost extends Serializable {
* @param useExternalMemory indicates whether to use the external memory cache; setting this flag to
* true may reduce the RAM cost of running XGBoost within Spark
* @param missing the value that represents a missing value in the dataset
* @param inputCol the name of input column, "features" as default value
* @param featureCol the name of input column, "features" as default value
* @param labelCol the name of output column, "label" as default value
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training fails
* @return the trained XGBoostModel
*/
@throws(classOf[XGBoostError])
def trainWithDataFrame(trainingData: Dataset[_],
params: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
useExternalMemory: Boolean = false, missing: Float = Float.NaN,
inputCol: String = "features", labelCol: String = "label"): XGBoostModel = {
def trainWithDataFrame(
trainingData: Dataset[_],
params: Map[String, Any],
round: Int,
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN,
featureCol: String = "features",
labelCol: String = "label"): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
new XGBoostEstimator(inputCol, labelCol, params, round, nWorkers, obj, eval,
useExternalMemory, missing).fit(trainingData)
val estimator = new XGBoostEstimator(params, round, nWorkers, obj, eval,
useExternalMemory, missing)
estimator.setFeaturesCol(featureCol).setLabelCol(labelCol).fit(trainingData)
}
private[spark] def isClassificationTask(objective: Option[Any]): Boolean = {
objective.isDefined && {
val objStr = objective.get.toString
objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
objStr != "rank:pairwise")
}
}
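For reference, a sketch of how this predicate routes common objective strings (illustrative only; isClassificationTask is private[spark], so such checks would live in a test under the same package, and the objective names are the standard XGBoost ones, not new API):

// Expected routing for standard objective strings, following the logic above.
assert(XGBoost.isClassificationTask(Some("binary:logistic")))   // classification
assert(XGBoost.isClassificationTask(Some("multi:softmax")))     // classification
assert(!XGBoost.isClassificationTask(Some("reg:linear")))       // regression
assert(!XGBoost.isClassificationTask(Some("count:poisson")))    // regression
assert(!XGBoost.isClassificationTask(Some("rank:pairwise")))    // ranking, treated as non-classification
assert(!XGBoost.isClassificationTask(None))                     // no objective given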
/**
@ -157,9 +176,9 @@ object XGBoost extends Serializable {
*/
@deprecated(since = "0.7", message = "this method is deprecated since 0.7, users are encouraged" +
" to switch to trainWithRDD")
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
def train(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
trainWithRDD(trainingData, configMap, round, nWorkers, obj, eval, useExternalMemory, missing)
}
@ -180,10 +199,15 @@ object XGBoost extends Serializable {
* @return XGBoostModel when successful training
*/
@throws(classOf[XGBoostError])
def trainWithRDD(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
def trainWithRDD(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
if (obj != null) {
require(configMap.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
" you have to specify the objective type as classification or regression with a" +
" customized objective function")
}
val tracker = new RabitTracker(nWorkers)
implicit val sc = trainingData.sparkContext
var overridedConfMap = configMap
@ -209,7 +233,13 @@ object XGBoost extends Serializable {
val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal")
if (returnVal == 0) {
boosters.first()
convertBoosterToXGBoostModel(boosters.first(),
isClassificationTask(
if (obj == null) {
configMap.get("objective")
} else {
configMap.get("obj_type")
}))
} else {
try {
if (sparkJobThread.isAlive) {
@ -223,6 +253,21 @@ object XGBoost extends Serializable {
}
}
private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = {
val featureCol = inputStream.readUTF()
val labelCol = inputStream.readUTF()
val predictionCol = inputStream.readUTF()
(featureCol, labelCol, predictionCol)
}
private def setGeneralModelParams(
featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
XGBoostModel = {
xgBoostModel.setFeaturesCol(featureCol)
xgBoostModel.setLabelCol(labelCol)
xgBoostModel.setPredictionCol(predCol)
}
/**
* Load XGBoost model from path in HDFS-compatible file system
*
@ -233,7 +278,29 @@ object XGBoost extends Serializable {
XGBoostModel = {
val path = new Path(modelPath)
val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path)
val xgBoostModel = new XGBoostModel(SXGBoost.loadModel(dataInStream))
xgBoostModel
val modelType = dataInStream.readUTF()
val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream)
modelType match {
case "_cls_" =>
val rawPredictionCol = dataInStream.readUTF()
val thresholdLength = dataInStream.readInt()
var thresholds: Array[Double] = null
if (thresholdLength != -1) {
thresholds = new Array[Double](thresholdLength)
for (i <- 0 until thresholdLength) {
thresholds(i) = dataInStream.readDouble()
}
}
val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel).
asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol)
if (thresholdLength != -1) {
xgBoostModel.setThresholds(thresholds)
}
xgBoostModel
case "_reg_" =>
val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
}
}
}
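To make the new persistence logic concrete, a minimal save/load round trip sketched from the code above; it assumes a prepared RDD[MLLabeledPoint] named trainRDD and a paramMap like the ones in the test suites below, and the output path is a placeholder:

import org.apache.spark.SparkContext

implicit val sc: SparkContext = trainRDD.sparkContext
val model = XGBoost.trainWithRDD(trainRDD, paramMap, round = 5, nWorkers = 2)
// saveModelAsHadoopFile writes a type marker ("_cls_" or "_reg_"), the general column params,
// classifier-only params (rawPredictionCol, thresholds) when applicable, and finally the booster
model.saveModelAsHadoopFile("/tmp/xgboost-model")
XGBoost.loadModelFromHadoopFile("/tmp/xgboost-model") match {
  case _: XGBoostClassificationModel => // objective was detected as a classification task
  case _: XGBoostRegressionModel => // reg:*, count:poisson and rank:pairwise objectives land here
}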

View File

@ -0,0 +1,153 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.ml.linalg.{Vector => MLVector, DenseVector => MLDenseVector}
import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
class XGBoostClassificationModel private[spark](
override val uid: String, _booster: Booster)
extends XGBoostModel(_booster) {
def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), _booster)
// scalastyle:off
final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ")
setDefault(outputMargin, false)
def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel]
final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).")
setDefault(rawPredictionCol, "probabilities")
final def getRawPredictionCol: String = $(rawPredictionCol)
def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel]
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
def getThresholds: Array[Double] = $(thresholds)
def setThresholds(value: Array[Double]): XGBoostClassificationModel =
set(thresholds, value).asInstanceOf[XGBoostClassificationModel]
// scalastyle:on
private def predictRaw(
testSet: Dataset[_],
temporalColName: Option[String] = None,
forceTransformedScore: Option[Boolean] = None): DataFrame = {
val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin)))
testSet.sparkSession.createDataFrame(predictRDD, schema = {
StructType(testSet.schema.add(StructField(
temporalColName.getOrElse($(rawPredictionCol)),
ArrayType(FloatType, containsNull = false), nullable = false)))
})
}
private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = {
val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol"))
val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol"))
val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF)
val allColumnNames = testSet.columns ++ Seq($(predictionCol))
tempDF.select(allColumnNames(0), allColumnNames.tail: _*)
}
private def argMax(vector: mutable.WrappedArray[Float]): Double = {
vector.zipWithIndex.maxBy(_._1)._2
}
private def raw2prediction(rawPrediction: mutable.WrappedArray[Float]): Double = {
if (!isDefined(thresholds)) {
argMax(rawPrediction)
} else {
probability2prediction(rawPrediction)
}
}
private def probability2prediction(probability: mutable.WrappedArray[Float]): Double = {
if (!isDefined(thresholds)) {
argMax(probability)
} else {
val thresholds: Array[Double] = getThresholds
val scaledProbability: mutable.WrappedArray[Double] =
probability.zip(thresholds).map { case (p, t) =>
if (t == 0.0) Double.PositiveInfinity else p / t
}
argMax(scaledProbability.map(_.toFloat))
}
}
override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
transformSchema(testSet.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
if ($(outputMargin)) {
setRawPredictionCol("margin")
}
var outputData = testSet
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
outputData = predictRaw(testSet)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
if ($(rawPredictionCol).nonEmpty) {
require(!$(outputMargin), "XGBoost does not support output final prediction with" +
" untransformed margin. Please set predictionCol as \"\" when setting outputMargin as" +
" true")
val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol)))
outputData = outputData.withColumn($(predictionCol), rawToPredUDF)
} else {
outputData = fromFeatureToPrediction(testSet)
}
numColsOutput += 1
}
if (numColsOutput == 0) {
this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
outputData.toDF()
}
private[spark] var numOfClasses = 2
def numClasses: Int = numOfClasses
override def copy(extra: ParamMap): XGBoostClassificationModel = {
defaultCopy(extra)
}
override protected def predict(features: MLVector): Double = {
throw new Exception("XGBoost does not support online prediction ")
}
}
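A minimal usage sketch for the classification-specific parameters introduced here; clsModel is assumed to be an XGBoostClassificationModel returned by the training API, and testDF a DataFrame with a "features" column:

clsModel.setRawPredictionCol("probabilities") // per-class scores, or raw margins when outputMargin is true
clsModel.setThresholds(Array(0.4, 0.6))       // one non-negative threshold per class
clsModel.setPredictionCol("prediction")       // argmax of the (optionally threshold-scaled) scores
val scored = clsModel.transform(testDF)       // appends the configured columns to testDF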

View File

@ -17,20 +17,18 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.spark.ml.{Predictor, Estimator}
import org.apache.spark.ml.Predictor
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector => MLVector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{NumericType, DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}
import org.apache.spark.sql.types.{StructType, DoubleType}
import org.apache.spark.sql.{Dataset, Row}
/**
* the estimator wrapping XGBoost to produce a trained model
*
* @param inputCol the name of input column
* @param labelCol the name of label column
* @param xgboostParams the parameters configuring XGBoost
* @param round the number of iterations to train
* @param nWorkers the total number of workers of xgboost
@ -39,43 +37,47 @@ import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}
* @param useExternalMemory whether to use external memory when training
* @param missing the value taken as missing
*/
class XGBoostEstimator(
inputCol: String, labelCol: String,
xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN)
extends Estimator[XGBoostModel] {
override val uid: String = Identifiable.randomUID("XGBoostEstimator")
class XGBoostEstimator private[spark](
override val uid: String, xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean, missing: Float)
extends Predictor[MLVector, XGBoostEstimator, XGBoostModel] {
def this(xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN) =
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean,
missing: Float)
/**
* produce an XGBoostModel by fitting the given dataset
*/
def fit(trainingSet: Dataset[_]): XGBoostModel = {
override def train(trainingSet: Dataset[_]): XGBoostModel = {
val instances = trainingSet.select(
col(inputCol), col(labelCol).cast(DoubleType)).rdd.map {
case Row(feature: Vector, label: Double) =>
col($(featuresCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
case Row(feature: MLVector, label: Double) =>
LabeledPoint(label, feature)
}
transformSchema(trainingSet.schema, logging = true)
val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, round, nWorkers, obj,
eval, useExternalMemory, missing).setParent(this)
copyValues(trainedModel)
val returnedModel = copyValues(trainedModel)
if (XGBoost.isClassificationTask(
if (obj == null) xgboostParams.get("objective") else xgboostParams.get("obj_type"))) {
val numClass = {
if (xgboostParams.contains("num_class")) {
xgboostParams("num_class").asInstanceOf[Int]
}
else {
2
}
}
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClass
}
returnedModel
}
override def copy(extra: ParamMap): Estimator[XGBoostModel] = {
override def copy(extra: ParamMap): XGBoostEstimator = {
defaultCopy(extra)
}
override def transformSchema(schema: StructType): StructType = {
// check input type, for now we only support vectorUDT as the input feature type
val inputType = schema(inputCol).dataType
require(inputType.equals(new VectorUDT), s"the type of input column $inputCol has to be VectorUDT")
// check label Type,
val labelType = schema(labelCol).dataType
require(labelType.isInstanceOf[NumericType], s"the type of label column $labelCol has to" +
s" be NumericType")
schema
}
}
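Since the estimator now extends Spark ML's Predictor, it can be configured and fitted like any other pipeline stage. A minimal sketch, where trainDF is a placeholder DataFrame with "features" and "label" columns and the parameter values are illustrative:

val estimator = new XGBoostEstimator(
  Map("eta" -> 0.1, "max_depth" -> 6, "objective" -> "binary:logistic"),
  round = 50, nWorkers = 4)
estimator.setFeaturesCol("features").setLabelCol("label")
// fit returns an XGBoostClassificationModel here because the objective is binary:logistic
val model: XGBoostModel = estimator.fit(trainDF)
// the estimator is a regular PipelineStage, so it can also be placed in an org.apache.spark.ml.Pipeline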

View File

@ -20,24 +20,48 @@ import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Model, PredictionModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{VectorUDT, DenseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.spark.ml.PredictionModel
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql._
import org.apache.spark.sql.types.{FloatType, ArrayType, DataType}
import org.apache.spark.{SparkContext, TaskContext}
class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable {
abstract class XGBoostModel(_booster: Booster)
extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {
var inputCol = "features"
var outputCol = "prediction"
var outputType: DataType = ArrayType(elementType = FloatType, containsNull = false)
def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
// scalastyle:off
final val useExternalMemory: Param[Boolean] = new Param[Boolean](this, "useExternalMemory", "whether to use external memory for prediction")
setDefault(useExternalMemory, false)
def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value)
// scalastyle:on
/**
* Predict leaf instances with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
*/
def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
if (testSamples.hasNext) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predictLeaf(dMatrix))
} else {
Iterator()
}
}
}
/**
* evaluate XGBoostModel with a RDD-wrapped dataset
@ -53,24 +77,25 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
* @param useExternalCache whether to use the external cache
* @return the average metric over all partitions
*/
def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false): String = {
require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter")
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val appName = evalDataset.context.appName
val allEvalMetrics = evalDataset.mapPartitions {
labeledPointsPartition =>
if (labeledPointsPartition.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
import DataUtils._
val cacheFileName = {
if (useExternalCache) {
s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
s"$appName-${TaskContext.get().stageId()}-$evalName" +
s"-deval_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
import DataUtils._
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
if (iter == -1) {
val predictions = broadcastBooster.value.predict(dMatrix)
@ -91,18 +116,48 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
s"$evalPrefix = $evalMetricMean"
}
/**
* Predict result with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
* @param missingValue the specified value to represent the missing value
*/
def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val sampleArray = testSamples.toList
val numRows = sampleArray.size
val numColumns = sampleArray.head.size
if (numRows == 0) {
Iterator()
} else {
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
// translate to required format
val flatSampleArray = new Array[Float](numRows * numColumns)
for (i <- flatSampleArray.indices) {
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
}
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
Rabit.shutdown()
Iterator(broadcastBooster.value.predict(dMatrix))
}
}
}
/**
* Predict result with the given test set (represented as RDD)
*
* @param testSet test set represented as RDD
* @param useExternalCache whether to use external cache for the test set
*/
def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
import DataUtils._
def predict(testSet: RDD[MLVector], useExternalCache: Boolean = false):
RDD[Array[Array[Float]]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
val appName = testSet.context.appName
testSet.mapPartitions { testSamples =>
if (testSamples.hasNext) {
import DataUtils._
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val cacheFileName = {
@ -122,48 +177,76 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
}
}
protected def transformImpl(testSet: Dataset[_]): DataFrame
/**
* Predict result with the given test set (represented as RDD)
* append leaf index of each row as an additional column in the original dataset
*
* @param testSet test set represented as RDD
* @param missingValue the specified value to represent the missing value
* @return the original dataframe with an additional column containing prediction results
*/
def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val sampleArray = testSamples.toList
val numRows = sampleArray.size
val numColumns = sampleArray.head.size
if (numRows == 0) {
Iterator()
} else {
// translate to required format
val flatSampleArray = new Array[Float](numRows * numColumns)
for (i <- flatSampleArray.indices) {
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
def transformLeaf(testSet: Dataset[_]): DataFrame = {
val predictRDD = produceRowRDD(testSet, predLeaf = true)
setPredictionCol("predLeaf")
transformSchema(testSet.schema, logging = true)
testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol),
ArrayType(FloatType, containsNull = false)))
}
protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false,
predLeaf: Boolean = false): RDD[Row] = {
val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
val appName = testSet.sparkSession.sparkContext.appName
testSet.rdd.mapPartitions {
rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector](
$(featuresCol))).toList.iterator
import DataUtils._
val cachePrefix = {
if ($(useExternalMemory)) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val testDataset = new DMatrix(vectorIterator, cachePrefix)
val rawPredictResults = {
if (!predLeaf) {
broadcastBooster.value.predict(testDataset, outputMargin).
map(Row(_)).iterator
} else {
broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
}
}
Rabit.shutdown()
// concatenate original data partition and predictions
rowItr1.zip(rawPredictResults).map {
case (originalColumns: Row, predictColumn: Row) =>
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
}
} else {
Iterator[Row]()
}
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
Iterator(broadcastBooster.value.predict(dMatrix))
}
}
}
/**
* Predict leaf instances with the given test set (represented as RDD)
* produces the prediction results and appends them as an additional column in the original dataset
* NOTE: the prediction results are kept in the original format of xgboost
*
* @param testSet test set represented as RDD
* @return the original dataframe with an additional column containing prediction results
*/
def predictLeaves(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
if (testSamples.hasNext) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predictLeaf(dMatrix))
} else {
Iterator()
}
}
override def transform(testSet: Dataset[_]): DataFrame = {
transformImpl(testSet)
}
private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = {
outputStream.writeUTF(getFeaturesCol)
outputStream.writeUTF(getLabelCol)
outputStream.writeUTF(getPredictionCol)
}
/**
@ -174,109 +257,34 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = {
val path = new Path(modelPath)
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
// output model type
this match {
case model: XGBoostClassificationModel =>
outputStream.writeUTF("_cls_")
saveGeneralModelParam(outputStream)
outputStream.writeUTF(model.getRawPredictionCol)
// threshold
// threshold length
if (!isDefined(model.thresholds)) {
outputStream.writeInt(-1)
} else {
val thresholdLength = model.getThresholds.length
outputStream.writeInt(thresholdLength)
for (i <- 0 until thresholdLength) {
outputStream.writeDouble(model.getThresholds(i))
}
}
case model: XGBoostRegressionModel =>
outputStream.writeUTF("_reg_")
// eventual prediction col
saveGeneralModelParam(outputStream)
}
// booster
_booster.saveModel(outputStream)
outputStream.close()
}
// override protected def featuresDataType: DataType = new VectorUDT
def booster: Booster = _booster
override val uid: String = Identifiable.randomUID("XGBoostModel")
override def copy(extra: ParamMap): XGBoostModel = {
defaultCopy(extra)
}
/**
* append leaf index of each row as an additional column in the original dataset
*
* @return the original dataframe with an additional column containing prediction results
*/
def transformLeaf(testSet: Dataset[_]): Unit = {
outputCol = "predLeaf"
transformSchema(testSet.schema, logging = true)
val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
val instances = testSet.rdd.mapPartitions {
rowIterator =>
if (rowIterator.hasNext) {
val (rowItr1, rowItr2) = rowIterator.duplicate
val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)).
toList.iterator
import DataUtils._
val testDataset = new DMatrix(vectorIterator, null)
val rowPredictResults = broadcastBooster.value.predictLeaf(testDataset)
val predictResults = rowPredictResults.map(prediction => Row(prediction)).iterator
rowItr1.zip(predictResults).map {
case (originalColumns: Row, predictColumn: Row) =>
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
}
} else {
Iterator[Row]()
}
}
testSet.sparkSession.createDataFrame(instances, testSet.schema.add(outputCol, outputType)).
cache()
}
/**
* produces the prediction results and appends them as an additional column in the original dataset
* NOTE: the prediction results are kept in the original format of xgboost
*
* @return the original dataframe with an additional column containing prediction results
*/
override def transform(testSet: Dataset[_]): DataFrame = {
transform(testSet, None)
}
/**
* produces the prediction results and appends them as an additional column in the original dataset
* NOTE: the prediction results are transformed by applying the transformation function
* predictResultTrans to the original xgboost output
*
* @param rawPredictTransformer the function to transform xgboost output to the expected format
* @return the original dataframe with an additional column containing prediction results
*/
def transform(testSet: Dataset[_], rawPredictTransformer: Option[Array[Float] => DataType]):
DataFrame = {
transformSchema(testSet.schema, logging = true)
val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
val instances = testSet.rdd.mapPartitions {
rowIterator =>
if (rowIterator.hasNext) {
val (rowItr1, rowItr2) = rowIterator.duplicate
val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)).
toList.iterator
import DataUtils._
val testDataset = new DMatrix(vectorIterator, null)
val rowPredictResults = broadcastBooster.value.predict(testDataset)
val predictResults = {
if (rawPredictTransformer.isDefined) {
rowPredictResults.map(prediction =>
Row(rawPredictTransformer.get(prediction))).iterator
} else {
rowPredictResults.map(prediction => Row(prediction)).iterator
}
}
rowItr1.zip(predictResults).map {
case (originalColumns: Row, predictColumn: Row) =>
Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
}
} else {
Iterator[Row]()
}
}
testSet.sparkSession.createDataFrame(instances, testSet.schema.add(outputCol, outputType)).
cache()
}
@DeveloperApi
override def transformSchema(schema: StructType): StructType = {
if (schema.fieldNames.contains(outputCol)) {
throw new IllegalArgumentException(s"Output column $outputCol already exists.")
}
val inputType = schema(inputCol).dataType
require(inputType.equals(new VectorUDT),
s"the type of input column $inputCol has to be VectorUDT")
val outputFields = schema.fields :+ StructField(outputCol, outputType, nullable = false)
StructType(outputFields)
}
}
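A minimal sketch of the RDD-level entry points kept on the abstract model; model, testVectors (RDD[MLVector]), denseVectors (RDD[MLDenseVector]) and labeledPoints (RDD[MLLabeledPoint]) are placeholders for data prepared elsewhere:

val predictions = model.predict(testVectors)                        // RDD[Array[Array[Float]]]
val densePreds  = model.predict(denseVectors, missingValue = -0.1f) // dense input with an explicit missing marker
val leafIndices = model.predictLeaves(testVectors)                  // leaf index of each row in every tree
val metric      = model.eval(labeledPoints, "eval1", iter = 5)      // or evalFunc = new EvalError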

View File

@ -0,0 +1,48 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
class XGBoostRegressionModel private[spark](override val uid: String, _booster: Booster)
extends XGBoostModel(_booster) {
def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster)
override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
transformSchema(testSet.schema, logging = true)
val predictRDD = produceRowRDD(testSet)
testSet.sparkSession.createDataFrame(predictRDD, schema =
StructType(testSet.schema.add(StructField($(predictionCol),
ArrayType(FloatType, containsNull = false), nullable = false)))
)
}
override protected def predict(features: MLVector): Double = {
throw new Exception("XGBoost does not support online prediction for now")
}
override def copy(extra: ParamMap): XGBoostRegressionModel = {
defaultCopy(extra)
}
}
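A short sketch of consuming the regression model's output; model is assumed to be trained with a reg:* objective and testDF to contain a "features" column:

val regModel = model.asInstanceOf[XGBoostRegressionModel]
regModel.setPredictionCol("final_prediction")
// transform appends the raw booster output as an Array[Float] column next to the input columns
regModel.transform(testDF).select("features", "final_prediction").show(5)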

View File

@ -50,6 +50,8 @@ class EvalError extends EvalTrait {
logger.error(ex)
return -1f
}
require(predicts.length == labels.length, s"predicts length ${predicts.length} has to be" +
s" equal with label length ${labels.length}")
val nrow: Int = predicts.length
for (i <- 0 until nrow) {
if (labels(i) == 0.0 && predicts(i)(0) > 0) {

View File

@ -17,20 +17,21 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
trait SharedSparkContext extends FunSuite with BeforeAndAfterAll with Serializable {
@transient protected implicit var sc: SparkContext = null
before {
override def beforeAll() {
// build SparkContext
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
set("spark.driver.memory", "512m")
sc = new SparkContext(sparkConf)
sc.setLogLevel("ERROR")
}
after {
override def afterAll() {
if (sc != null) {
sc.stop()
}

View File

@ -21,17 +21,23 @@ import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.LogFactory
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector => SparkVector}
import org.apache.spark.rdd.RDD
trait Utils extends Serializable {
protected val numWorkers = Runtime.getRuntime().availableProcessors()
protected var labeledPointsRDD: RDD[LabeledPoint] = null
protected def cleanExternalCache(prefix: String): Unit = {
val dir = new File(".")
for (file <- dir.listFiles() if file.getName.startsWith(prefix)) {
file.delete()
}
}
protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[LabeledPoint]
@ -41,6 +47,15 @@ trait Utils extends Serializable {
sampleList.toList
}
protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = {
val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[(Double, SparkVector)]
for (sample <- file.getLines()) {
sampleList += fromSVMStringToLabelAndVector(sample)
}
sampleList.toList
}
protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures(0).toDouble
@ -59,7 +74,10 @@ trait Utils extends Serializable {
}
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
sparkContext.parallelize(sampleList, numWorkers)
if (labeledPointsRDD == null) {
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
labeledPointsRDD = sparkContext.parallelize(sampleList, numWorkers)
}
labeledPointsRDD
}
}

View File

@ -0,0 +1,60 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite
class XGBoostConfigureSuite extends FunSuite with Utils {
test("nthread configuration must be equal to spark.task.cpus") {
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
set("spark.task.cpus", "4")
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
// start another app
val trainingRDD = buildTrainingRDD(customSparkContext)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "nthread" -> 6)
intercept[IllegalArgumentException] {
XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
}
customSparkContext.stop()
}
test("kryoSerializer test") {
labeledPointsRDD = null
val eval = new EvalError()
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val trainingRDD = buildTrainingRDD(customSparkContext)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
customSparkContext.stop()
}
}

View File

@ -25,77 +25,27 @@ import scala.io.Source
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.sql._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
class XGBoostDFSuite extends SharedSparkContext with Utils {
private def loadRow(filePath: String): List[Row] = {
val file = Source.fromFile(new File(filePath))
val rowList = new ListBuffer[Row]
for (rowLine <- file.getLines()) {
rowList += fromSVMStringToRow(rowLine)
private var trainingDF: DataFrame = null
private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None): DataFrame = {
if (trainingDF == null) {
val rowList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
val labeledPointsRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers)
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
import sparkSession.implicits._
trainingDF = sparkSession.createDataset(labeledPointsRDD).toDF
}
rowList.toList
trainingDF
}
private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None):
DataFrame = {
val rowList = loadRow(getClass.getResource("/agaricus.txt.train").getFile)
val rowRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers)
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
sparkSession.createDataFrame(rowRDD,
StructType(Array(StructField("label", DoubleType, nullable = false),
StructField("features", new VectorUDT, nullable = false))))
}
private def fromSVMStringToRow(line: String): Row = {
val (label, sv) = fromSVMStringToLabelAndVector(line)
Row(label, sv)
}
test("test consistency between training with dataframe and RDD") {
val trainingDF = buildTrainingDataframe()
val trainingRDD = buildTrainingRDD(sc)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = false)
val xgBoostModelWithRDD = XGBoost.trainWithRDD(trainingRDD, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = false)
val eval = new EvalError()
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
assert(
eval.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) ===
eval.eval(xgBoostModelWithRDD.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix))
}
test("test transform of dataframe-based model") {
val trainingDF = buildTrainingDataframe()
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = false)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
val testRowsRDD = sc.parallelize(testSet.zipWithIndex, numWorkers).map{
case (instance: LabeledPoint, id: Int) =>
Row(id, instance.features, instance.label)
}
val testDF = trainingDF.sparkSession.createDataFrame(testRowsRDD, StructType(
Array(StructField("id", IntegerType),
StructField("features", new VectorUDT), StructField("label", DoubleType))))
xgBoostModelWithDF.transform(testDF).show()
}
test("test order preservation of dataframe-based model") {
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap
test("test consistency and order preservation of dataframe-based model") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingItr = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile).
iterator
val (testItr, auxTestItr) =
@ -105,25 +55,109 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
val predResultFromSeq = xgboostModel.predict(testDMatrix)
val testRowsRDD = sc.parallelize(
auxTestItr.toList.zipWithIndex, numWorkers).map {
val testSetItr = auxTestItr.zipWithIndex.map {
case (instance: LabeledPoint, id: Int) =>
Row(id, instance.features, instance.label)
(id, instance.features, instance.label)
}
val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = false)
val testDF = trainingDF.sqlContext.createDataFrame(testRowsRDD, StructType(
Array(StructField("id", IntegerType), StructField("features", new VectorUDT),
StructField("label", DoubleType))))
val predResultsFromDF =
xgBoostModelWithDF.transform(testDF).collect().map(row => (row.getAs[Int]("id"),
row.getAs[mutable.WrappedArray[Float]]("prediction"))).toMap
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
"id", "features", "label")
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row =>
(row.getAs[Int]("id"), row.getAs[mutable.WrappedArray[Float]]("probabilities"))
).toMap
assert(testDF.count() === predResultsFromDF.size)
for (i <- predResultFromSeq.indices) {
assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
for (j <- predResultFromSeq(i).indices) {
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
}
}
cleanExternalCache("XGBoostDFSuite")
}
test("test transformLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = false)
val testSetItr = testItr.zipWithIndex.map {
case (instance: LabeledPoint, id: Int) =>
(id, instance.features, instance.label)
}
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
"id", "features", "label")
xgBoostModelWithDF.transformLeaf(testDF).show()
}
test("test schema of XGBoostRegressionModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator.
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
(id, instance.features, instance.label)
}
val trainingDF = {
val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile)
val labeledPointsRDD = sc.parallelize(rowList, numWorkers)
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
import sparkSession.implicits._
sparkSession.createDataset(labeledPointsRDD).toDF
}
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.setPredictionCol("final_prediction")
val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
"id", "features", "label")
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id") === true)
assert(predictionDF.columns.contains("features") === true)
assert(predictionDF.columns.contains("label") === true)
assert(predictionDF.columns.contains("final_prediction") === true)
predictionDF.show()
cleanExternalCache("XGBoostDFSuite")
}
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
(id, instance.features, instance.label)
}
val trainingDF = buildTrainingDataframe()
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
"raw_prediction").setPredictionCol("final_prediction")
val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
"id", "features", "label")
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id") === true)
assert(predictionDF.columns.contains("features") === true)
assert(predictionDF.columns.contains("label") === true)
assert(predictionDF.columns.contains("raw_prediction") === true)
assert(predictionDF.columns.contains("final_prediction") === true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
setPredictionCol("final_prediction")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id") === true)
assert(predictionDF.columns.contains("features") === true)
assert(predictionDF.columns.contains("label") === true)
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction") === true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id") === true)
assert(predictionDF.columns.contains("features") === true)
assert(predictionDF.columns.contains("label") === true)
assert(predictionDF.columns.contains("raw_prediction") === true)
assert(predictionDF.columns.contains("final_prediction") === false)
cleanExternalCache("XGBoostDFSuite")
}
}

View File

@ -16,66 +16,47 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import scala.collection.mutable.ListBuffer
import scala.util.Random
import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val boosterRDD = XGBoost.buildDistributedBoosters(
trainingRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
val boosters = boosterRDD.collect()
val eval = new EvalError()
for (booster <- boosters) {
// the threshold is 0.11 because it does not sync boosters with AllReduce
val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
assert(eval.eval(predicts, testSetDMatrix) < 0.11)
}
cleanExternalCache("XGBoostSuite")
}
test("training with external memory cache") {
sc.stop()
sc = null
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(customSparkContext)
val trainingRDD = buildTrainingRDD(sc)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
customSparkContext.stop()
// clean
val dir = new File(".")
for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-0-dtrain_cache")) {
file.delete()
}
cleanExternalCache("XGBoostSuite")
}
test("test with dense vectors containing missing value") {
@ -106,10 +87,13 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
}
val trainingRDD = buildDenseRDD().repartition(4)
val testRDD = buildDenseRDD().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
useExternalMemory = true)
xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
// clean
cleanExternalCache("XGBoostSuite")
}
test("test consistency of prediction functions with RDD") {
@ -120,11 +104,12 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
for (i <- testSet.indices) {
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
}
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1 = predRDD.collect()(0)
assert(testRDD.count() === predResult1.length)
import DataUtils._
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
for (i <- predResult1.indices; j <- predResult1(i).indices) {
@ -134,9 +119,9 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
test("test eval functions with RDD") {
val trainingRDD = buildTrainingRDD(sc).cache()
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
}
@ -150,7 +135,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val testRDD = buildEmptyRDD()
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
println(xgBoostModel.predict(testRDD).collect().length === 0)
@ -164,8 +149,8 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
@ -177,41 +162,40 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
assert(loadedEvalResults == evalResults)
}
test("nthread configuration must be equal to spark.task.cpus") {
sc.stop()
sc = null
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
set("spark.task.cpus", "4")
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
// start another app
val trainingRDD = buildTrainingRDD(customSparkContext)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic", "nthread" -> 6).toMap
intercept[IllegalArgumentException] {
XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
}
customSparkContext.stop()
}
test("kryoSerializer test") {
sc.stop()
sc = null
val eval = new EvalError()
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val customSparkContext = new SparkContext(sparkConf)
customSparkContext.setLogLevel("ERROR")
val trainingRDD = buildTrainingRDD(customSparkContext)
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
customSparkContext.stop()
test("test save and load of different types of models") {
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val trainingRDD = buildTrainingRDD(sc)
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
// validate regression model
var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.setFeaturesCol("feature_col")
xgBoostModel.setLabelCol("label_col")
xgBoostModel.setPredictionCol("prediction_col")
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
assert(loadedXGBoostModel.getLabelCol == "label_col")
assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
// classification model
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5).deep)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
}
}

View File

@ -38,6 +38,8 @@ trait EvalTrait extends IEvaluation {
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
require(predicts.length == jdmat.getLabel.length, "predicts size and label size must match " +
s" predicts size: ${predicts.length}, label size: ${jdmat.getLabel.length}")
eval(predicts, new DMatrix(jdmat))
}
}
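For reference, a hypothetical custom metric against this interface; it assumes EvalTrait also declares getMetric, as the EvalError helper used by the Spark tests does:

import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}

class PositiveMissRate extends EvalTrait {
  override def getMetric: String = "positive-miss-rate"
  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    val labels = dmat.getLabel
    require(predicts.length == labels.length, "prediction and label counts must match")
    // fraction of positive-label rows whose first score is non-positive
    val misses = labels.indices.count(i => labels(i) == 1.0f && predicts(i)(0) <= 0f)
    misses.toFloat / labels.length
  }
}

Such a metric could then be passed as the eval argument of XGBoost.trainWithRDD or as evalFunc to XGBoostModel.eval.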