[jvm-packages] separate classification and regression model and integrate with ML package (#1608)
commit 1673bcbe7e (parent 3b9987ca9c)
@@ -49,7 +49,7 @@ addons:
 before_install:
   - source dmlc-core/scripts/travis/travis_setup_env.sh
   - export PYTHONPATH=${PYTHONPATH}:${PWD}/python-package
-  - echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m'" > ~/.mavenrc
+  - echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m -Dorg.slf4j.simpleLogger.defaultLogLevel=error'" > ~/.mavenrc
 
 install:
   - source tests/travis/setup.sh
@@ -31,6 +31,11 @@ mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
 cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py
 # copy test data files
 mkdir -p xgboost4j-spark/src/test/resources/
+cd ../demo/regression
+python mapfeat.py
+python mknfold.py machine.txt 1
+cd -
+cp ../demo/regression/machine.txt.t* xgboost4j-spark/src/test/resources/
 cp ../demo/data/agaricus.* xgboost4j-spark/src/test/resources/
 popd > /dev/null
 echo "complete"
@@ -20,6 +20,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
 import ml.dmlc.xgboost4j.scala.spark.{DataUtils, XGBoost}
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 
 object SparkWithRDD {
   def main(args: Array[String]): Unit = {
@@ -38,8 +40,10 @@ object SparkWithRDD {
     // number of iterations
     val numRound = args(0).toInt
     import DataUtils._
-    val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
-    val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().iterator
+    val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath).map(lp =>
+      MLLabeledPoint(lp.label, new MLDenseVector(lp.features.toArray)))
+    val testSet = MLUtils.loadLibSVMFile(sc, inputTestPath).collect().map(
+      lp => new MLDenseVector(lp.features.toArray)).iterator
     // training parameters
     val paramMap = List(
       "eta" -> 0.1f,
@@ -19,16 +19,17 @@ package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.JavaConverters._
 
 import ml.dmlc.xgboost4j.LabeledPoint
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
-import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
 
 object DataUtils extends Serializable {
-  implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[SparkLabeledPoint])
+
+  implicit def fromSparkPointsToXGBoostPointsJava(sps: Iterator[MLLabeledPoint])
   : java.util.Iterator[LabeledPoint] = {
     fromSparkPointsToXGBoostPoints(sps).asJava
   }
 
-  implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[SparkLabeledPoint]):
+  implicit def fromSparkPointsToXGBoostPoints(sps: Iterator[MLLabeledPoint]):
   Iterator[LabeledPoint] = {
     for (p <- sps) yield {
       p.features match {
@@ -45,6 +46,7 @@ object DataUtils extends Serializable {
   : java.util.Iterator[LabeledPoint] = {
     fromSparkVectorToXGBoostPoints(sps).asJava
   }
+
   implicit def fromSparkVectorToXGBoostPoints(sps: Iterator[Vector])
   : Iterator[LabeledPoint] = {
     for (p <- sps) yield {
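Note on the DataUtils hunks above: the implicit conversions now start from the new-style org.apache.spark.ml.feature.LabeledPoint rather than the mllib one, so code that builds a DMatrix from Spark data is expected to import the ml.feature/ml.linalg types. A minimal sketch of bringing the implicits into scope (the in-memory points collection is hypothetical; the DMatrix(Iterator, cacheInfo) constructor usage mirrors the one appearing later in this commit):

  import ml.dmlc.xgboost4j.scala.DMatrix
  import ml.dmlc.xgboost4j.scala.spark.DataUtils._
  import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
  import org.apache.spark.ml.linalg.Vectors

  // the implicit in DataUtils turns Iterator[MLLabeledPoint] into the
  // Iterator[LabeledPoint] that the DMatrix constructor expects
  val points = Seq(MLLabeledPoint(1.0, Vectors.dense(0.0, 1.5, 2.0)))
  val dmat = new DMatrix(points.iterator, null)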
@@ -23,26 +23,30 @@ import scala.collection.mutable.ListBuffer
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.Path
-import org.apache.spark.mllib.linalg.SparseVector
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.hadoop.fs.{FSDataInputStream, Path}
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.Dataset
 import org.apache.spark.{SparkContext, TaskContext}
 
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private implicit def convertBoosterToXGBoostModel(booster: Booster)
-      (implicit sc: SparkContext): XGBoostModel = {
-    new XGBoostModel(booster)
+  private def convertBoosterToXGBoostModel(booster: Booster, isClassification: Boolean):
+      XGBoostModel = {
+    if (!isClassification) {
+      new XGBoostRegressionModel(booster)
+    } else {
+      new XGBoostClassificationModel(booster)
+    }
   }
 
   private def fromDenseToSparseLabeledPoints(
-      denseLabeledPoints: Iterator[LabeledPoint],
-      missing: Float): Iterator[LabeledPoint] = {
+      denseLabeledPoints: Iterator[MLLabeledPoint],
+      missing: Float): Iterator[MLLabeledPoint] = {
     if (!missing.isNaN) {
-      val sparseLabeledPoints = new ListBuffer[LabeledPoint]
+      val sparseLabeledPoints = new ListBuffer[MLLabeledPoint]
       for (labelPoint <- denseLabeledPoints) {
         val dVector = labelPoint.features.toDense
         val indices = new ListBuffer[Int]
@@ -55,7 +59,7 @@ object XGBoost extends Serializable {
         }
         val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
           values.toArray)
-        sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector)
+        sparseLabeledPoints += MLLabeledPoint(labelPoint.label, sparseVector)
       }
       sparseLabeledPoints.iterator
     } else {
@@ -64,7 +68,7 @@ object XGBoost extends Serializable {
   }
 
   private[spark] def buildDistributedBoosters(
-      trainingData: RDD[LabeledPoint],
+      trainingData: RDD[MLLabeledPoint],
       xgBoostConfMap: Map[String, Any],
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
@@ -124,20 +128,35 @@ object XGBoost extends Serializable {
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
    * @param missing the value represented the missing value in the dataset
-   * @param inputCol the name of input column, "features" as default value
+   * @param featureCol the name of input column, "features" as default value
    * @param labelCol the name of output column, "label" as default value
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
-  def trainWithDataFrame(trainingData: Dataset[_],
-      params: Map[String, Any], round: Int,
-      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-      useExternalMemory: Boolean = false, missing: Float = Float.NaN,
-      inputCol: String = "features", labelCol: String = "label"): XGBoostModel = {
+  def trainWithDataFrame(
+      trainingData: Dataset[_],
+      params: Map[String, Any],
+      round: Int,
+      nWorkers: Int,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null,
+      useExternalMemory: Boolean = false,
+      missing: Float = Float.NaN,
+      featureCol: String = "features",
+      labelCol: String = "label"): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
-    new XGBoostEstimator(inputCol, labelCol, params, round, nWorkers, obj, eval,
-      useExternalMemory, missing).fit(trainingData)
+    val estimator = new XGBoostEstimator(params, round, nWorkers, obj, eval,
+      useExternalMemory, missing)
+    estimator.setFeaturesCol(featureCol).setLabelCol(labelCol).fit(trainingData)
+  }
+
+  private[spark] def isClassificationTask(objective: Option[Any]): Boolean = {
+    objective.isDefined && {
+      val objStr = objective.get.toString
+      objStr == "classification" || (!objStr.startsWith("reg:") && objStr != "count:poisson" &&
+        objStr != "rank:pairwise")
+    }
   }
 
   /**
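With the reworked trainWithDataFrame entry point above, the feature/label column names are forwarded to an XGBoostEstimator and the objective string decides whether a classification or a regression model comes back. A minimal usage sketch under those assumptions (the trainingDF value is hypothetical):

  import ml.dmlc.xgboost4j.scala.spark.XGBoost

  val params = Map(
    "eta" -> 0.1, "max_depth" -> 6,
    "objective" -> "binary:logistic")  // not "reg:*", "count:poisson" or "rank:pairwise" => classification model
  // trainingDF is a hypothetical DataFrame with a "features" Vector column and a "label" column
  val model = XGBoost.trainWithDataFrame(trainingDF, params, round = 10, nWorkers = 2)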
@@ -157,9 +176,9 @@ object XGBoost extends Serializable {
    */
   @deprecated(since = "0.7", message = "this method is deprecated since 0.7, users are encouraged" +
     " to switch to trainWithRDD")
-  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
+  def train(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
       useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
     trainWithRDD(trainingData, configMap, round, nWorkers, obj, eval, useExternalMemory, missing)
   }
@@ -180,10 +199,15 @@ object XGBoost extends Serializable {
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
-  def trainWithRDD(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
+  def trainWithRDD(trainingData: RDD[MLLabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
       useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
+    if (obj != null) {
+      require(configMap.get("obj_type").isDefined, "parameter \"obj_type\" is not defined," +
+        " you have to specify the objective type as classification or regression with a" +
+        " customized objective function")
+    }
     val tracker = new RabitTracker(nWorkers)
     implicit val sc = trainingData.sparkContext
     var overridedConfMap = configMap
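Because a custom ObjectiveTrait hides the objective string, the check added above makes callers declare the task type explicitly through "obj_type". A sketch of the expected call shape (customObj is a hypothetical ObjectiveTrait implementation and trainRDD a hypothetical RDD of ml.feature.LabeledPoint):

  val paramsWithCustomObj = Map(
    "eta" -> 0.1,
    "obj_type" -> "classification")  // or "regression"
  val model = XGBoost.trainWithRDD(trainRDD, paramsWithCustomObj, round = 10, nWorkers = 2,
    obj = customObj)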
@@ -209,7 +233,13 @@ object XGBoost extends Serializable {
     val returnVal = tracker.waitFor()
     logger.info(s"Rabit returns with exit code $returnVal")
     if (returnVal == 0) {
-      boosters.first()
+      convertBoosterToXGBoostModel(boosters.first(),
+        isClassificationTask(
+          if (obj == null) {
+            configMap.get("objective")
+          } else {
+            configMap.get("obj_type")
+          }))
     } else {
       try {
         if (sparkJobThread.isAlive) {
@@ -223,6 +253,21 @@ object XGBoost extends Serializable {
       }
     }
   }
 
+  private def loadGeneralModelParams(inputStream: FSDataInputStream): (String, String, String) = {
+    val featureCol = inputStream.readUTF()
+    val labelCol = inputStream.readUTF()
+    val predictionCol = inputStream.readUTF()
+    (featureCol, labelCol, predictionCol)
+  }
+
+  private def setGeneralModelParams(
+      featureCol: String, labelCol: String, predCol: String, xgBoostModel: XGBoostModel):
+      XGBoostModel = {
+    xgBoostModel.setFeaturesCol(featureCol)
+    xgBoostModel.setLabelCol(labelCol)
+    xgBoostModel.setPredictionCol(predCol)
+  }
+
   /**
    * Load XGBoost model from path in HDFS-compatible file system
    *
@@ -233,7 +278,29 @@ object XGBoost extends Serializable {
       XGBoostModel = {
     val path = new Path(modelPath)
     val dataInStream = path.getFileSystem(sparkContext.hadoopConfiguration).open(path)
-    val xgBoostModel = new XGBoostModel(SXGBoost.loadModel(dataInStream))
-    xgBoostModel
+    val modelType = dataInStream.readUTF()
+    val (featureCol, labelCol, predictionCol) = loadGeneralModelParams(dataInStream)
+    modelType match {
+      case "_cls_" =>
+        val rawPredictionCol = dataInStream.readUTF()
+        val thresholdLength = dataInStream.readInt()
+        var thresholds: Array[Double] = null
+        if (thresholdLength != -1) {
+          thresholds = new Array[Double](thresholdLength)
+          for (i <- 0 until thresholdLength) {
+            thresholds(i) = dataInStream.readDouble()
+          }
+        }
+        val xgBoostModel = new XGBoostClassificationModel(SXGBoost.loadModel(dataInStream))
+        setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel).
+          asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(rawPredictionCol)
+        if (thresholdLength != -1) {
+          xgBoostModel.setThresholds(thresholds)
+        }
+        xgBoostModel
+      case "_reg_" =>
+        val xgBoostModel = new XGBoostRegressionModel(SXGBoost.loadModel(dataInStream))
+        setGeneralModelParams(featureCol, labelCol, predictionCol, xgBoostModel)
+    }
   }
 }
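The persisted format now begins with a model-type tag ("_cls_" or "_reg_") followed by the general column names (and, for classifiers, the raw-prediction column and thresholds) before the booster bytes, mirroring saveModelAsHadoopFile further down in this commit. A round-trip sketch; the path is hypothetical and the loader method name is assumed from the surrounding loader shown above rather than visible in this hunk:

  // save (method visible later in this commit)
  model.saveModelAsHadoopFile("hdfs:///tmp/xgb-model")(sc)
  // load back; the result is a classification or regression model depending on the tag
  val restored = XGBoost.loadModelFromHadoopFile("hdfs:///tmp/xgb-model")(sc)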
@@ -0,0 +1,153 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import scala.collection.mutable
+
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
+import org.apache.spark.ml.linalg.{Vector => MLVector, DenseVector => MLDenseVector}
+import org.apache.spark.ml.param.{DoubleArrayParam, Param, ParamMap}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+
+class XGBoostClassificationModel private[spark](
+    override val uid: String, _booster: Booster)
+  extends XGBoostModel(_booster) {
+
+  def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostClassificationModel"), _booster)
+
+  // scalastyle:off
+
+  final val outputMargin: Param[Boolean] = new Param[Boolean](this, "outputMargin", "whether to output untransformed margin value ")
+
+  setDefault(outputMargin, false)
+
+  def setOutputMargin(value: Boolean): XGBoostModel = set(outputMargin, value).asInstanceOf[XGBoostClassificationModel]
+
+  final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "Column name for raw prediction output of xgboost. If outputMargin is true, the column contains untransformed margin value; otherwise it is the probability for each class (by default).")
+
+  setDefault(rawPredictionCol, "probabilities")
+
+  final def getRawPredictionCol: String = $(rawPredictionCol)
+
+  def setRawPredictionCol(value: String): XGBoostClassificationModel = set(rawPredictionCol, value).asInstanceOf[XGBoostClassificationModel]
+
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
+
+  def getThresholds: Array[Double] = $(thresholds)
+
+  def setThresholds(value: Array[Double]): XGBoostClassificationModel =
+    set(thresholds, value).asInstanceOf[XGBoostClassificationModel]
+
+  // scalastyle:on
+
+  private def predictRaw(
+      testSet: Dataset[_],
+      temporalColName: Option[String] = None,
+      forceTransformedScore: Option[Boolean] = None): DataFrame = {
+    val predictRDD = produceRowRDD(testSet, forceTransformedScore.getOrElse($(outputMargin)))
+    testSet.sparkSession.createDataFrame(predictRDD, schema = {
+      StructType(testSet.schema.add(StructField(
+        temporalColName.getOrElse($(rawPredictionCol)),
+        ArrayType(FloatType, containsNull = false), nullable = false)))
+    })
+  }
+
+  private def fromFeatureToPrediction(testSet: Dataset[_]): Dataset[_] = {
+    val rawPredictionDF = predictRaw(testSet, Some("rawPredictionCol"))
+    val predictionUDF = udf(raw2prediction _).apply(col("rawPredictionCol"))
+    val tempDF = rawPredictionDF.withColumn($(predictionCol), predictionUDF)
+    val allColumnNames = testSet.columns ++ Seq($(predictionCol))
+    tempDF.select(allColumnNames(0), allColumnNames.tail: _*)
+  }
+
+  private def argMax(vector: mutable.WrappedArray[Float]): Double = {
+    vector.zipWithIndex.maxBy(_._1)._2
+  }
+
+  private def raw2prediction(rawPrediction: mutable.WrappedArray[Float]): Double = {
+    if (!isDefined(thresholds)) {
+      argMax(rawPrediction)
+    } else {
+      probability2prediction(rawPrediction)
+    }
+  }
+
+  private def probability2prediction(probability: mutable.WrappedArray[Float]): Double = {
+    if (!isDefined(thresholds)) {
+      argMax(probability)
+    } else {
+      val thresholds: Array[Double] = getThresholds
+      val scaledProbability: mutable.WrappedArray[Double] =
+        probability.zip(thresholds).map { case (p, t) =>
+          if (t == 0.0) Double.PositiveInfinity else p / t
+        }
+      argMax(scaledProbability.map(_.toFloat))
+    }
+  }
+
+  override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
+    transformSchema(testSet.schema, logging = true)
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".transform() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+    if ($(outputMargin)) {
+      setRawPredictionCol("margin")
+    }
+    var outputData = testSet
+    var numColsOutput = 0
+    if ($(rawPredictionCol).nonEmpty) {
+      outputData = predictRaw(testSet)
+      numColsOutput += 1
+    }
+
+    if ($(predictionCol).nonEmpty) {
+      if ($(rawPredictionCol).nonEmpty) {
+        require(!$(outputMargin), "XGBoost does not support output final prediction with" +
+          " untransformed margin. Please set predictionCol as \"\" when setting outputMargin as" +
+          " true")
+        val rawToPredUDF = udf(raw2prediction _).apply(col($(rawPredictionCol)))
+        outputData = outputData.withColumn($(predictionCol), rawToPredUDF)
+      } else {
+        outputData = fromFeatureToPrediction(testSet)
+      }
+      numColsOutput += 1
+    }
+
+    if (numColsOutput == 0) {
+      this.logWarning(s"$uid: XGBoostClassificationModel.transform() was called as NOOP" +
+        " since no output columns were set.")
+    }
+    outputData.toDF()
+  }
+
+  private[spark] var numOfClasses = 2
+
+  def numClasses: Int = numOfClasses
+
+  override def copy(extra: ParamMap): XGBoostClassificationModel = {
+    defaultCopy(extra)
+  }
+
+  override protected def predict(features: MLVector): Double = {
+    throw new Exception("XGBoost does not support online prediction ")
+  }
+}
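A short sketch of how the classifier-specific params introduced in this new file are meant to be used on a fitted model (the model and testDF values are hypothetical):

  val cls = model.asInstanceOf[XGBoostClassificationModel]
  cls.setRawPredictionCol("probabilities")   // default; holds the per-class scores
    .setThresholds(Array(0.4, 0.6))          // per-class thresholds, length must equal numClasses
  val scored = cls.transform(testDF)         // appends the raw prediction and prediction columns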
@@ -17,20 +17,18 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
-import org.apache.spark.ml.{Predictor, Estimator}
+import org.apache.spark.ml.Predictor
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.{Vector => MLVector, VectorUDT}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{NumericType, DoubleType, StructType}
-import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}
+import org.apache.spark.sql.types.{StructType, DoubleType}
+import org.apache.spark.sql.{Dataset, Row}
 
 /**
  * the estimator wrapping XGBoost to produce a training model
  *
- * @param inputCol the name of input column
- * @param labelCol the name of label column
  * @param xgboostParams the parameters configuring XGBoost
  * @param round the number of iterations to train
  * @param nWorkers the total number of workers of xgboost
@@ -39,43 +37,47 @@ import org.apache.spark.sql.{DataFrame, TypedColumn, Dataset, Row}
  * @param useExternalMemory whether to use external memory when training
  * @param missing the value taken as missing
  */
-class XGBoostEstimator(
-    inputCol: String, labelCol: String,
-    xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
-    obj: ObjectiveTrait = null,
-    eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN)
-  extends Estimator[XGBoostModel] {
+class XGBoostEstimator private[spark](
+    override val uid: String, xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
+    obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean, missing: Float)
+  extends Predictor[MLVector, XGBoostEstimator, XGBoostModel] {
 
-  override val uid: String = Identifiable.randomUID("XGBoostEstimator")
+  def this(xgboostParams: Map[String, Any], round: Int, nWorkers: Int,
+      obj: ObjectiveTrait = null,
+      eval: EvalTrait = null, useExternalMemory: Boolean = false, missing: Float = Float.NaN) =
+    this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any], round: Int,
+      nWorkers: Int, obj: ObjectiveTrait, eval: EvalTrait, useExternalMemory: Boolean,
+      missing: Float)
 
   /**
    * produce a XGBoostModel by fitting the given dataset
    */
-  def fit(trainingSet: Dataset[_]): XGBoostModel = {
+  override def train(trainingSet: Dataset[_]): XGBoostModel = {
     val instances = trainingSet.select(
-      col(inputCol), col(labelCol).cast(DoubleType)).rdd.map {
-      case Row(feature: Vector, label: Double) =>
+      col($(featuresCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
+      case Row(feature: MLVector, label: Double) =>
         LabeledPoint(label, feature)
     }
     transformSchema(trainingSet.schema, logging = true)
     val trainedModel = XGBoost.trainWithRDD(instances, xgboostParams, round, nWorkers, obj,
       eval, useExternalMemory, missing).setParent(this)
-    copyValues(trainedModel)
+    val returnedModel = copyValues(trainedModel)
+    if (XGBoost.isClassificationTask(
+      if (obj == null) xgboostParams.get("objective") else xgboostParams.get("obj_type"))) {
+      val numClass = {
+        if (xgboostParams.contains("num_class")) {
+          xgboostParams("num_class").asInstanceOf[Int]
+        }
+        else {
+          2
+        }
+      }
+      returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = numClass
+    }
+    returnedModel
   }
 
-  override def copy(extra: ParamMap): Estimator[XGBoostModel] = {
+  override def copy(extra: ParamMap): XGBoostEstimator = {
     defaultCopy(extra)
   }
-
-  override def transformSchema(schema: StructType): StructType = {
-    // check input type, for now we only support vectorUDT as the input feature type
-    val inputType = schema(inputCol).dataType
-    require(inputType.equals(new VectorUDT), s"the type of input column $inputCol has to VectorUDT")
-    // check label Type,
-    val labelType = schema(labelCol).dataType
-    require(labelType.isInstanceOf[NumericType], s"the type of label column $labelCol has to" +
-      s" be NumericType")
-    schema
-  }
 }
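Since XGBoostEstimator now extends org.apache.spark.ml.Predictor, it inherits the standard featuresCol/labelCol/predictionCol params and can sit inside an ML Pipeline. A hedged sketch of that usage (trainDF is hypothetical; any feature-assembly stages are omitted):

  import org.apache.spark.ml.Pipeline

  val xgb = new XGBoostEstimator(
    Map("eta" -> 0.1, "objective" -> "binary:logistic"), round = 10, nWorkers = 2)
  xgb.setFeaturesCol("features").setLabelCol("label")
  val pipeline = new Pipeline().setStages(Array(xgb))
  val pipelineModel = pipeline.fit(trainDF)  // the fitted stage is an XGBoostClassificationModel for this objective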
@@ -20,24 +20,48 @@ import scala.collection.JavaConverters._
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit}
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait}
-import org.apache.hadoop.fs.Path
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.{Model, PredictionModel}
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.linalg.{VectorUDT, DenseVector, Vector}
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.hadoop.fs.{FSDataOutputStream, Path}
+import org.apache.spark.ml.PredictionModel
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
+import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector, Vector => MLVector}
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql._
+import org.apache.spark.sql.types.{FloatType, ArrayType, DataType}
 import org.apache.spark.{SparkContext, TaskContext}
 
-class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializable {
+abstract class XGBoostModel(_booster: Booster)
+  extends PredictionModel[MLVector, XGBoostModel] with Serializable with Params {
 
-  var inputCol = "features"
-  var outputCol = "prediction"
-  var outputType: DataType = ArrayType(elementType = FloatType, containsNull = false)
+  def setLabelCol(name: String): XGBoostModel = set(labelCol, name)
+
+  // scalastyle:off
+
+  final val useExternalMemory: Param[Boolean] = new Param[Boolean](this, "useExternalMemory", "whether to use external memory for prediction")
+
+  setDefault(useExternalMemory, false)
+
+  def setExternalMemory(value: Boolean): XGBoostModel = set(useExternalMemory, value)
+
+  // scalastyle:on
+
+  /**
+   * Predict leaf instances with the given test set (represented as RDD)
+   *
+   * @param testSet test set represented as RDD
+   */
+  def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = {
+    import DataUtils._
+    val broadcastBooster = testSet.sparkContext.broadcast(_booster)
+    testSet.mapPartitions { testSamples =>
+      if (testSamples.hasNext) {
+        val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
+        Iterator(broadcastBooster.value.predictLeaf(dMatrix))
+      } else {
+        Iterator()
+      }
+    }
+  }
+
   /**
    * evaluate XGBoostModel with a RDD-wrapped dataset
@@ -53,24 +77,25 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
    * @param useExternalCache if use external cache
    * @return the average metric over all partitions
    */
-  def eval(evalDataset: RDD[LabeledPoint], evalName: String, evalFunc: EvalTrait = null,
+  def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
       iter: Int = -1, useExternalCache: Boolean = false): String = {
-    require(evalFunc != null || iter != -1, "you have to specify value of either eval or iter")
+    require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
     val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
     val appName = evalDataset.context.appName
     val allEvalMetrics = evalDataset.mapPartitions {
       labeledPointsPartition =>
         if (labeledPointsPartition.hasNext) {
-          val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+          val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
           Rabit.init(rabitEnv.asJava)
-          import DataUtils._
          val cacheFileName = {
            if (useExternalCache) {
-              s"$appName-${TaskContext.get().stageId()}-deval_cache-${TaskContext.getPartitionId()}"
+              s"$appName-${TaskContext.get().stageId()}-$evalName" +
+                s"-deval_cache-${TaskContext.getPartitionId()}"
            } else {
              null
            }
          }
+          import DataUtils._
          val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
          if (iter == -1) {
            val predictions = broadcastBooster.value.predict(dMatrix)
@@ -91,18 +116,48 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
     s"$evalPrefix = $evalMetricMean"
   }
 
+  /**
+   * Predict result with the given test set (represented as RDD)
+   *
+   * @param testSet test set represented as RDD
+   * @param missingValue the specified value to represent the missing value
+   */
+  def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
+    val broadcastBooster = testSet.sparkContext.broadcast(_booster)
+    testSet.mapPartitions { testSamples =>
+      val sampleArray = testSamples.toList
+      val numRows = sampleArray.size
+      val numColumns = sampleArray.head.size
+      if (numRows == 0) {
+        Iterator()
+      } else {
+        val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
+        Rabit.init(rabitEnv.asJava)
+        // translate to required format
+        val flatSampleArray = new Array[Float](numRows * numColumns)
+        for (i <- flatSampleArray.indices) {
+          flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
+        }
+        val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
+        Rabit.shutdown()
+        Iterator(broadcastBooster.value.predict(dMatrix))
+      }
+    }
+  }
+
   /**
    * Predict result with the given test set (represented as RDD)
    *
    * @param testSet test set represented as RDD
    * @param useExternalCache whether to use external cache for the test set
    */
-  def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
-    import DataUtils._
+  def predict(testSet: RDD[MLVector], useExternalCache: Boolean = false):
+      RDD[Array[Array[Float]]] = {
     val broadcastBooster = testSet.sparkContext.broadcast(_booster)
     val appName = testSet.context.appName
     testSet.mapPartitions { testSamples =>
       if (testSamples.hasNext) {
+        import DataUtils._
         val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
         Rabit.init(rabitEnv.asJava)
         val cacheFileName = {
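The dense-vector overload added above takes the missing-value marker explicitly and builds a flat, row-major DMatrix per partition. A minimal sketch (the vectors RDD is hypothetical):

  import org.apache.spark.ml.linalg.{DenseVector => MLDenseVector}
  import org.apache.spark.rdd.RDD

  // vectors: RDD[MLDenseVector], one row per instance; 0.0f is treated as the missing marker here
  val rawScores: RDD[Array[Array[Float]]] = model.predict(vectors, missingValue = 0.0f)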
@@ -122,48 +177,76 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
         }
       }
     }
 
+  protected def transformImpl(testSet: Dataset[_]): DataFrame
+
   /**
-   * Predict result with the given test set (represented as RDD)
+   * append leaf index of each row as an additional column in the original dataset
    *
-   * @param testSet test set represented as RDD
-   * @param missingValue the specified value to represent the missing value
+   * @return the original dataframe with an additional column containing prediction results
    */
-  def predict(testSet: RDD[DenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
-    val broadcastBooster = testSet.sparkContext.broadcast(_booster)
-    testSet.mapPartitions { testSamples =>
-      val sampleArray = testSamples.toList
-      val numRows = sampleArray.size
-      val numColumns = sampleArray.head.size
-      if (numRows == 0) {
-        Iterator()
-      } else {
-        // translate to required format
-        val flatSampleArray = new Array[Float](numRows * numColumns)
-        for (i <- flatSampleArray.indices) {
-          flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
+  def transformLeaf(testSet: Dataset[_]): DataFrame = {
+    val predictRDD = produceRowRDD(testSet, predLeaf = true)
+    setPredictionCol("predLeaf")
+    transformSchema(testSet.schema, logging = true)
+    testSet.sparkSession.createDataFrame(predictRDD, testSet.schema.add($(predictionCol),
+      ArrayType(FloatType, containsNull = false)))
+  }
+
+  protected def produceRowRDD(testSet: Dataset[_], outputMargin: Boolean = false,
+      predLeaf: Boolean = false): RDD[Row] = {
+    val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
+    val appName = testSet.sparkSession.sparkContext.appName
+    testSet.rdd.mapPartitions {
+      rowIterator =>
+        if (rowIterator.hasNext) {
+          val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
+          Rabit.init(rabitEnv.asJava)
+          val (rowItr1, rowItr2) = rowIterator.duplicate
+          val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[MLVector](
+            $(featuresCol))).toList.iterator
+          import DataUtils._
+          val cachePrefix = {
+            if ($(useExternalMemory)) {
+              s"$appName-${TaskContext.get().stageId()}-dtest_cache-${TaskContext.getPartitionId()}"
+            } else {
+              null
+            }
+          }
+          val testDataset = new DMatrix(vectorIterator, cachePrefix)
+          val rawPredictResults = {
+            if (!predLeaf) {
+              broadcastBooster.value.predict(testDataset, outputMargin).
+                map(Row(_)).iterator
+            } else {
+              broadcastBooster.value.predictLeaf(testDataset).map(Row(_)).iterator
+            }
+          }
+          Rabit.shutdown()
+          // concatenate original data partition and predictions
+          rowItr1.zip(rawPredictResults).map {
+            case (originalColumns: Row, predictColumn: Row) =>
+              Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
+          }
+        } else {
+          Iterator[Row]()
         }
-        val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
-        Iterator(broadcastBooster.value.predict(dMatrix))
-      }
     }
   }
 
   /**
-   * Predict leaf instances with the given test set (represented as RDD)
+   * produces the prediction results and append as an additional column in the original dataset
+   * NOTE: the prediction results is kept as the original format of xgboost
    *
-   * @param testSet test set represented as RDD
+   * @return the original dataframe with an additional column containing prediction results
    */
-  def predictLeaves(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
-    import DataUtils._
-    val broadcastBooster = testSet.sparkContext.broadcast(_booster)
-    testSet.mapPartitions { testSamples =>
-      if (testSamples.hasNext) {
-        val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
-        Iterator(broadcastBooster.value.predictLeaf(dMatrix))
-      } else {
-        Iterator()
-      }
-    }
-  }
+  override def transform(testSet: Dataset[_]): DataFrame = {
+    transformImpl(testSet)
+  }
+
+  private def saveGeneralModelParam(outputStream: FSDataOutputStream): Unit = {
+    outputStream.writeUTF(getFeaturesCol)
+    outputStream.writeUTF(getLabelCol)
+    outputStream.writeUTF(getPredictionCol)
+  }
 
   /**
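With this hunk, transform() delegates to the subclass transformImpl, while transformLeaf appends the per-tree leaf indices as a "predLeaf" column. A sketch against a fitted model (testDF is hypothetical):

  val withPredictions = model.transform(testDF)   // appends the prediction column(s)
  val withLeaves = model.transformLeaf(testDF)    // appends a "predLeaf" array column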
@@ -174,109 +257,34 @@ class XGBoostModel(_booster: Booster) extends Model[XGBoostModel] with Serializa
   def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = {
     val path = new Path(modelPath)
     val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
+    // output model type
+    this match {
+      case model: XGBoostClassificationModel =>
+        outputStream.writeUTF("_cls_")
+        saveGeneralModelParam(outputStream)
+        outputStream.writeUTF(model.getRawPredictionCol)
+        // threshold
+        // threshold length
+        if (!isDefined(model.thresholds)) {
+          outputStream.writeInt(-1)
+        } else {
+          val thresholdLength = model.getThresholds.length
+          outputStream.writeInt(thresholdLength)
+          for (i <- 0 until thresholdLength) {
+            outputStream.writeDouble(model.getThresholds(i))
+          }
+        }
+      case model: XGBoostRegressionModel =>
+        outputStream.writeUTF("_reg_")
+        // eventual prediction col
+        saveGeneralModelParam(outputStream)
+    }
+    // booster
     _booster.saveModel(outputStream)
     outputStream.close()
   }
 
+  // override protected def featuresDataType: DataType = new VectorUDT
+
   def booster: Booster = _booster
-
-  override val uid: String = Identifiable.randomUID("XGBoostModel")
-
-  override def copy(extra: ParamMap): XGBoostModel = {
-    defaultCopy(extra)
-  }
-
-  /**
-   * append leaf index of each row as an additional column in the original dataset
-   *
-   * @return the original dataframe with an additional column containing prediction results
-   */
-  def transformLeaf(testSet: Dataset[_]): Unit = {
-    outputCol = "predLeaf"
-    transformSchema(testSet.schema, logging = true)
-    val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
-    val instances = testSet.rdd.mapPartitions {
-      rowIterator =>
-        if (rowIterator.hasNext) {
-          val (rowItr1, rowItr2) = rowIterator.duplicate
-          val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)).
-            toList.iterator
-          import DataUtils._
-          val testDataset = new DMatrix(vectorIterator, null)
-          val rowPredictResults = broadcastBooster.value.predictLeaf(testDataset)
-          val predictResults = rowPredictResults.map(prediction => Row(prediction)).iterator
-          rowItr1.zip(predictResults).map {
-            case (originalColumns: Row, predictColumn: Row) =>
-              Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
-          }
-        } else {
-          Iterator[Row]()
-        }
-    }
-    testSet.sparkSession.createDataFrame(instances, testSet.schema.add(outputCol, outputType)).
-      cache()
-  }
-
-  /**
-   * produces the prediction results and append as an additional column in the original dataset
-   * NOTE: the prediction results is kept as the original format of xgboost
-   *
-   * @return the original dataframe with an additional column containing prediction results
-   */
-  override def transform(testSet: Dataset[_]): DataFrame = {
-    transform(testSet, None)
-  }
-
-  /**
-   * produces the prediction results and append as an additional column in the original dataset
-   * NOTE: the prediction results is transformed by applying the transformation function
-   * predictResultTrans to the original xgboost output
-   *
-   * @param rawPredictTransformer the function to transform xgboost output to the expected format
-   * @return the original dataframe with an additional column containing prediction results
-   */
-  def transform(testSet: Dataset[_], rawPredictTransformer: Option[Array[Float] => DataType]):
-    DataFrame = {
-    transformSchema(testSet.schema, logging = true)
-    val broadcastBooster = testSet.sparkSession.sparkContext.broadcast(_booster)
-    val instances = testSet.rdd.mapPartitions {
-      rowIterator =>
-        if (rowIterator.hasNext) {
-          val (rowItr1, rowItr2) = rowIterator.duplicate
-          val vectorIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](inputCol)).
-            toList.iterator
-          import DataUtils._
-          val testDataset = new DMatrix(vectorIterator, null)
-          val rowPredictResults = broadcastBooster.value.predict(testDataset)
-          val predictResults = {
-            if (rawPredictTransformer.isDefined) {
-              rowPredictResults.map(prediction =>
-                Row(rawPredictTransformer.get(prediction))).iterator
-            } else {
-              rowPredictResults.map(prediction => Row(prediction)).iterator
-            }
-          }
-          rowItr1.zip(predictResults).map {
-            case (originalColumns: Row, predictColumn: Row) =>
-              Row.fromSeq(originalColumns.toSeq ++ predictColumn.toSeq)
-          }
-        } else {
-          Iterator[Row]()
-        }
-    }
-    testSet.sparkSession.createDataFrame(instances, testSet.schema.add(outputCol, outputType)).
-      cache()
-  }
-
-  @DeveloperApi
-  override def transformSchema(schema: StructType): StructType = {
-    if (schema.fieldNames.contains(outputCol)) {
-      throw new IllegalArgumentException(s"Output column $outputCol already exists.")
-    }
-    val inputType = schema(inputCol).dataType
-    require(inputType.equals(new VectorUDT),
-      s"the type of input column $inputCol has to be VectorUDT")
-    val outputFields = schema.fields :+ StructField(outputCol, outputType, nullable = false)
-    StructType(outputFields)
-  }
 }
@@ -0,0 +1,48 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.Booster
+import org.apache.spark.ml.linalg.{Vector => MLVector}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
+
+class XGBoostRegressionModel private[spark](override val uid: String, _booster: Booster)
+  extends XGBoostModel(_booster) {
+
+  def this(_booster: Booster) = this(Identifiable.randomUID("XGBoostRegressionModel"), _booster)
+
+  override protected def transformImpl(testSet: Dataset[_]): DataFrame = {
+    transformSchema(testSet.schema, logging = true)
+    val predictRDD = produceRowRDD(testSet)
+    testSet.sparkSession.createDataFrame(predictRDD, schema =
+      StructType(testSet.schema.add(StructField($(predictionCol),
+        ArrayType(FloatType, containsNull = false), nullable = false)))
+    )
+  }
+
+  override protected def predict(features: MLVector): Double = {
+    throw new Exception("XGBoost does not support online prediction for now")
+  }
+
+  override def copy(extra: ParamMap): XGBoostRegressionModel = {
+    defaultCopy(extra)
+  }
+}
@@ -50,6 +50,8 @@ class EvalError extends EvalTrait {
       logger.error(ex)
       return -1f
     }
+    require(predicts.length == labels.length, s"predicts length ${predicts.length} has to be" +
+      s" equal with label length ${labels.length}")
     val nrow: Int = predicts.length
     for (i <- 0 until nrow) {
       if (labels(i) == 0.0 && predicts(i)(0) > 0) {
@@ -17,20 +17,21 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
-class SharedSparkContext extends FunSuite with BeforeAndAfter with Serializable {
+trait SharedSparkContext extends FunSuite with BeforeAndAfterAll with Serializable {
 
   @transient protected implicit var sc: SparkContext = null
 
-  before {
+  override def beforeAll() {
     // build SparkContext
-    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
+      set("spark.driver.memory", "512m")
     sc = new SparkContext(sparkConf)
     sc.setLogLevel("ERROR")
   }
 
-  after {
+  override def afterAll() {
     if (sc != null) {
       sc.stop()
     }
@@ -21,17 +21,23 @@ import java.io.File
 import scala.collection.mutable.ListBuffer
 import scala.io.Source
 
-import ml.dmlc.xgboost4j.java.XGBoostError
-import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
-import org.apache.commons.logging.LogFactory
 import org.apache.spark.SparkContext
-import org.apache.spark.mllib.linalg.{DenseVector, Vector => SparkVector}
-import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.{DenseVector, Vector => SparkVector}
 import org.apache.spark.rdd.RDD
 
 trait Utils extends Serializable {
   protected val numWorkers = Runtime.getRuntime().availableProcessors()
 
+  protected var labeledPointsRDD: RDD[LabeledPoint] = null
+
+  protected def cleanExternalCache(prefix: String): Unit = {
+    val dir = new File(".")
+    for (file <- dir.listFiles() if file.getName.startsWith(prefix)) {
+      file.delete()
+    }
+  }
+
   protected def loadLabelPoints(filePath: String): List[LabeledPoint] = {
     val file = Source.fromFile(new File(filePath))
     val sampleList = new ListBuffer[LabeledPoint]
@@ -41,6 +47,15 @@ trait Utils extends Serializable {
     sampleList.toList
   }
 
+  protected def loadLabelAndVector(filePath: String): List[(Double, SparkVector)] = {
+    val file = Source.fromFile(new File(filePath))
+    val sampleList = new ListBuffer[(Double, SparkVector)]
+    for (sample <- file.getLines()) {
+      sampleList += fromSVMStringToLabelAndVector(sample)
+    }
+    sampleList.toList
+  }
+
   protected def fromSVMStringToLabelAndVector(line: String): (Double, SparkVector) = {
     val labelAndFeatures = line.split(" ")
     val label = labelAndFeatures(0).toDouble
@ -59,7 +74,10 @@ trait Utils extends Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
|
protected def buildTrainingRDD(sparkContext: SparkContext): RDD[LabeledPoint] = {
|
||||||
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
|
if (labeledPointsRDD == null) {
|
||||||
sparkContext.parallelize(sampleList, numWorkers)
|
val sampleList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
|
||||||
|
labeledPointsRDD = sparkContext.parallelize(sampleList, numWorkers)
|
||||||
|
}
|
||||||
|
labeledPointsRDD
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
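buildTrainingRDD now memoizes the parallelized training data in labeledPointsRDD, so repeated calls across tests reuse one RDD instead of re-reading the resource file. A self-contained restatement of that pattern; loadLabelPoints is left abstract here because its body lives elsewhere in Utils:

import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.rdd.RDD

trait CachedTrainingData {
  protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
  @transient protected var labeledPointsRDD: RDD[LabeledPoint] = null

  // the file parser is assumed to be provided by the mixing-in suite
  protected def loadLabelPoints(filePath: String): List[LabeledPoint]

  protected def buildTrainingRDD(sc: SparkContext): RDD[LabeledPoint] = {
    if (labeledPointsRDD == null) {
      // parallelize once, then hand the same RDD to every test
      val samples = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
      labeledPointsRDD = sc.parallelize(samples, numWorkers)
    }
    labeledPointsRDD
  }
}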
@ -0,0 +1,60 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||||
|
import org.apache.spark.{SparkConf, SparkContext}
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
class XGBoostConfigureSuite extends FunSuite with Utils {
|
||||||
|
|
||||||
|
test("nthread configuration must be equal to spark.task.cpus") {
|
||||||
|
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
|
||||||
|
set("spark.task.cpus", "4")
|
||||||
|
val customSparkContext = new SparkContext(sparkConf)
|
||||||
|
customSparkContext.setLogLevel("ERROR")
|
||||||
|
// start another app
|
||||||
|
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "nthread" -> 6)
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||||
|
}
|
||||||
|
customSparkContext.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("kryoSerializer test") {
|
||||||
|
labeledPointsRDD = null
|
||||||
|
val eval = new EvalError()
|
||||||
|
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
|
||||||
|
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||||
|
sparkConf.registerKryoClasses(Array(classOf[Booster]))
|
||||||
|
val customSparkContext = new SparkContext(sparkConf)
|
||||||
|
customSparkContext.setLogLevel("ERROR")
|
||||||
|
val trainingRDD = buildTrainingRDD(customSparkContext)
|
||||||
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
import DataUtils._
|
||||||
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||||
|
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
|
testSetDMatrix) < 0.1)
|
||||||
|
customSparkContext.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
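The new XGBoostConfigureSuite covers two configuration paths: rejecting an nthread setting that disagrees with spark.task.cpus, and training with Kryo serialization after registering the Booster class. A minimal sketch of that Kryo setup on its own, outside the test harness (object name is illustrative):

import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.{SparkConf, SparkContext}

object KryoExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("KryoExample")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    // register Booster so trained models ship efficiently between executors
    conf.registerKryoClasses(Array(classOf[Booster]))
    val sc = new SparkContext(conf)
    try {
      // ... train with XGBoost.trainWithRDD as in the suite ...
    } finally {
      sc.stop()
    }
  }
}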
@ -25,77 +25,27 @@ import scala.io.Source
|
|||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.mllib.linalg.VectorUDT
|
import org.apache.spark.ml.feature.LabeledPoint
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}
|
|
||||||
|
|
||||||
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
class XGBoostDFSuite extends SharedSparkContext with Utils {
|
||||||
|
|
||||||
private def loadRow(filePath: String): List[Row] = {
|
private var trainingDF: DataFrame = null
|
||||||
val file = Source.fromFile(new File(filePath))
|
|
||||||
val rowList = new ListBuffer[Row]
|
private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None): DataFrame = {
|
||||||
for (rowLine <- file.getLines()) {
|
if (trainingDF == null) {
|
||||||
rowList += fromSVMStringToRow(rowLine)
|
val rowList = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile)
|
||||||
|
val labeledPointsRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers)
|
||||||
|
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
||||||
|
import sparkSession.implicits._
|
||||||
|
trainingDF = sparkSession.createDataset(labeledPointsRDD).toDF
|
||||||
}
|
}
|
||||||
rowList.toList
|
trainingDF
|
||||||
}
|
}
|
||||||
|
|
||||||
private def buildTrainingDataframe(sparkContext: Option[SparkContext] = None):
|
test("test consistency and order preservation of dataframe-based model") {
|
||||||
DataFrame = {
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
val rowList = loadRow(getClass.getResource("/agaricus.txt.train").getFile)
|
"objective" -> "binary:logistic")
|
||||||
val rowRDD = sparkContext.getOrElse(sc).parallelize(rowList, numWorkers)
|
|
||||||
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
|
||||||
sparkSession.createDataFrame(rowRDD,
|
|
||||||
StructType(Array(StructField("label", DoubleType, nullable = false),
|
|
||||||
StructField("features", new VectorUDT, nullable = false))))
|
|
||||||
}
|
|
||||||
|
|
||||||
private def fromSVMStringToRow(line: String): Row = {
|
|
||||||
val (label, sv) = fromSVMStringToLabelAndVector(line)
|
|
||||||
Row(label, sv)
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test consistency between training with dataframe and RDD") {
|
|
||||||
val trainingDF = buildTrainingDataframe()
|
|
||||||
val trainingRDD = buildTrainingRDD(sc)
|
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
|
||||||
"objective" -> "binary:logistic").toMap
|
|
||||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
|
||||||
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
|
||||||
val xgBoostModelWithRDD = XGBoost.trainWithRDD(trainingRDD, paramMap,
|
|
||||||
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
|
||||||
val eval = new EvalError()
|
|
||||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
|
||||||
import DataUtils._
|
|
||||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
|
||||||
assert(
|
|
||||||
eval.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
|
|
||||||
testSetDMatrix) ===
|
|
||||||
eval.eval(xgBoostModelWithRDD.booster.predict(testSetDMatrix, outPutMargin = true),
|
|
||||||
testSetDMatrix))
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test transform of dataframe-based model") {
|
|
||||||
val trainingDF = buildTrainingDataframe()
|
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
|
||||||
"objective" -> "binary:logistic").toMap
|
|
||||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
|
||||||
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
|
||||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile)
|
|
||||||
val testRowsRDD = sc.parallelize(testSet.zipWithIndex, numWorkers).map{
|
|
||||||
case (instance: LabeledPoint, id: Int) =>
|
|
||||||
Row(id, instance.features, instance.label)
|
|
||||||
}
|
|
||||||
val testDF = trainingDF.sparkSession.createDataFrame(testRowsRDD, StructType(
|
|
||||||
Array(StructField("id", IntegerType),
|
|
||||||
StructField("features", new VectorUDT), StructField("label", DoubleType))))
|
|
||||||
xgBoostModelWithDF.transform(testDF).show()
|
|
||||||
}
|
|
||||||
|
|
||||||
test("test order preservation of dataframe-based model") {
|
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
|
||||||
"objective" -> "binary:logistic").toMap
|
|
||||||
val trainingItr = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile).
|
val trainingItr = loadLabelPoints(getClass.getResource("/agaricus.txt.train").getFile).
|
||||||
iterator
|
iterator
|
||||||
val (testItr, auxTestItr) =
|
val (testItr, auxTestItr) =
|
||||||
@ -105,25 +55,109 @@ class XGBoostDFSuite extends SharedSparkContext with Utils {
|
|||||||
val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
|
val testDMatrix = new DMatrix(new JDMatrix(testItr, null))
|
||||||
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
|
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, 5)
|
||||||
val predResultFromSeq = xgboostModel.predict(testDMatrix)
|
val predResultFromSeq = xgboostModel.predict(testDMatrix)
|
||||||
val testRowsRDD = sc.parallelize(
|
val testSetItr = auxTestItr.zipWithIndex.map {
|
||||||
auxTestItr.toList.zipWithIndex, numWorkers).map {
|
|
||||||
case (instance: LabeledPoint, id: Int) =>
|
case (instance: LabeledPoint, id: Int) =>
|
||||||
Row(id, instance.features, instance.label)
|
(id, instance.features, instance.label)
|
||||||
}
|
}
|
||||||
val trainingDF = buildTrainingDataframe()
|
val trainingDF = buildTrainingDataframe()
|
||||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||||
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
||||||
val testDF = trainingDF.sqlContext.createDataFrame(testRowsRDD, StructType(
|
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
|
||||||
Array(StructField("id", IntegerType), StructField("features", new VectorUDT),
|
"id", "features", "label")
|
||||||
StructField("label", DoubleType))))
|
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||||
val predResultsFromDF =
|
collect().map(row =>
|
||||||
xgBoostModelWithDF.transform(testDF).collect().map(row => (row.getAs[Int]("id"),
|
(row.getAs[Int]("id"), row.getAs[mutable.WrappedArray[Float]]("probabilities"))
|
||||||
row.getAs[mutable.WrappedArray[Float]]("prediction"))).toMap
|
).toMap
|
||||||
|
assert(testDF.count() === predResultsFromDF.size)
|
||||||
for (i <- predResultFromSeq.indices) {
|
for (i <- predResultFromSeq.indices) {
|
||||||
assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
|
assert(predResultFromSeq(i).length === predResultsFromDF(i).length)
|
||||||
for (j <- predResultFromSeq(i).indices) {
|
for (j <- predResultFromSeq(i).indices) {
|
||||||
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
|
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
cleanExternalCache("XGBoostDFSuite")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test transformLeaf") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
|
val trainingDF = buildTrainingDataframe()
|
||||||
|
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||||
|
round = 5, nWorkers = numWorkers, useExternalMemory = false)
|
||||||
|
val testSetItr = testItr.zipWithIndex.map {
|
||||||
|
case (instance: LabeledPoint, id: Int) =>
|
||||||
|
(id, instance.features, instance.label)
|
||||||
|
}
|
||||||
|
val testDF = trainingDF.sparkSession.createDataFrame(testSetItr.toList).toDF(
|
||||||
|
"id", "features", "label")
|
||||||
|
xgBoostModelWithDF.transformLeaf(testDF).show()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test schema of XGBoostRegressionModel") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "reg:linear")
|
||||||
|
val testItr = loadLabelPoints(getClass.getResource("/machine.txt.test").getFile).iterator.
|
||||||
|
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
|
||||||
|
(id, instance.features, instance.label)
|
||||||
|
}
|
||||||
|
val trainingDF = {
|
||||||
|
val rowList = loadLabelPoints(getClass.getResource("/machine.txt.train").getFile)
|
||||||
|
val labeledPointsRDD = sc.parallelize(rowList, numWorkers)
|
||||||
|
val sparkSession = SparkSession.builder().appName("XGBoostDFSuite").getOrCreate()
|
||||||
|
import sparkSession.implicits._
|
||||||
|
sparkSession.createDataset(labeledPointsRDD).toDF
|
||||||
|
}
|
||||||
|
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||||
|
round = 5, nWorkers = numWorkers, useExternalMemory = true)
|
||||||
|
xgBoostModelWithDF.setPredictionCol("final_prediction")
|
||||||
|
val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
|
||||||
|
"id", "features", "label")
|
||||||
|
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
|
||||||
|
assert(predictionDF.columns.contains("id") === true)
|
||||||
|
assert(predictionDF.columns.contains("features") === true)
|
||||||
|
assert(predictionDF.columns.contains("label") === true)
|
||||||
|
assert(predictionDF.columns.contains("final_prediction") === true)
|
||||||
|
predictionDF.show()
|
||||||
|
cleanExternalCache("XGBoostDFSuite")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test schema of XGBoostClassificationModel") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
val testItr = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator.
|
||||||
|
zipWithIndex.map { case (instance: LabeledPoint, id: Int) =>
|
||||||
|
(id, instance.features, instance.label)
|
||||||
|
}
|
||||||
|
val trainingDF = buildTrainingDataframe()
|
||||||
|
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||||
|
round = 5, nWorkers = numWorkers, useExternalMemory = true)
|
||||||
|
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
|
||||||
|
"raw_prediction").setPredictionCol("final_prediction")
|
||||||
|
val testDF = trainingDF.sparkSession.createDataFrame(testItr.toList).toDF(
|
||||||
|
"id", "features", "label")
|
||||||
|
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
|
||||||
|
assert(predictionDF.columns.contains("id") === true)
|
||||||
|
assert(predictionDF.columns.contains("features") === true)
|
||||||
|
assert(predictionDF.columns.contains("label") === true)
|
||||||
|
assert(predictionDF.columns.contains("raw_prediction") === true)
|
||||||
|
assert(predictionDF.columns.contains("final_prediction") === true)
|
||||||
|
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
|
||||||
|
setPredictionCol("final_prediction")
|
||||||
|
predictionDF = xgBoostModelWithDF.transform(testDF)
|
||||||
|
assert(predictionDF.columns.contains("id") === true)
|
||||||
|
assert(predictionDF.columns.contains("features") === true)
|
||||||
|
assert(predictionDF.columns.contains("label") === true)
|
||||||
|
assert(predictionDF.columns.contains("raw_prediction") === false)
|
||||||
|
assert(predictionDF.columns.contains("final_prediction") === true)
|
||||||
|
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
|
||||||
|
setRawPredictionCol("raw_prediction").setPredictionCol("")
|
||||||
|
predictionDF = xgBoostModelWithDF.transform(testDF)
|
||||||
|
assert(predictionDF.columns.contains("id") === true)
|
||||||
|
assert(predictionDF.columns.contains("features") === true)
|
||||||
|
assert(predictionDF.columns.contains("label") === true)
|
||||||
|
assert(predictionDF.columns.contains("raw_prediction") === true)
|
||||||
|
assert(predictionDF.columns.contains("final_prediction") === false)
|
||||||
|
cleanExternalCache("XGBoostDFSuite")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
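In the DataFrame suite the training data is now produced by turning an RDD of ml.feature.LabeledPoint into a Dataset and converting it to a DataFrame, rather than assembling Rows against an explicit VectorUDT schema. The core of that conversion in isolation (a sketch with made-up toy data, not the suite's resource file):

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

object ToDataFrameExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("ToDataFrameExample").getOrCreate()
    import spark.implicits._

    val points = Seq(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)))
    val rdd = spark.sparkContext.parallelize(points)
    // the case-class encoder yields a DataFrame with label/features columns
    val df: DataFrame = spark.createDataset(rdd).toDF()
    df.printSchema()
    spark.stop()
  }
}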
@ -16,66 +16,47 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import java.io.File
|
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
|
|
||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => ScalaXGBoost}
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
import org.apache.spark.mllib.linalg.{Vector => SparkVector, Vectors}
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.ml.feature.LabeledPoint
|
||||||
|
import org.apache.spark.ml.linalg.{Vector => SparkVector, Vectors}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.{SparkConf, SparkContext}
|
|
||||||
|
|
||||||
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
||||||
|
|
||||||
test("build RDD containing boosters with the specified worker number") {
|
test("build RDD containing boosters with the specified worker number") {
|
||||||
val trainingRDD = buildTrainingRDD(sc)
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
|
||||||
import DataUtils._
|
|
||||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
|
||||||
val boosterRDD = XGBoost.buildDistributedBoosters(
|
val boosterRDD = XGBoost.buildDistributedBoosters(
|
||||||
trainingRDD,
|
trainingRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap,
|
"objective" -> "binary:logistic").toMap,
|
||||||
new scala.collection.mutable.HashMap[String, String],
|
new scala.collection.mutable.HashMap[String, String],
|
||||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
|
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true)
|
||||||
val boosterCount = boosterRDD.count()
|
val boosterCount = boosterRDD.count()
|
||||||
assert(boosterCount === 2)
|
assert(boosterCount === 2)
|
||||||
val boosters = boosterRDD.collect()
|
cleanExternalCache("XGBoostSuite")
|
||||||
val eval = new EvalError()
|
|
||||||
for (booster <- boosters) {
|
|
||||||
// the threshold is 0.11 because it does not sync boosters with AllReduce
|
|
||||||
val predicts = booster.predict(testSetDMatrix, outPutMargin = true)
|
|
||||||
assert(eval.eval(predicts, testSetDMatrix) < 0.11)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
test("training with external memory cache") {
|
test("training with external memory cache") {
|
||||||
sc.stop()
|
|
||||||
sc = null
|
|
||||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
|
|
||||||
val customSparkContext = new SparkContext(sparkConf)
|
|
||||||
customSparkContext.setLogLevel("ERROR")
|
|
||||||
val eval = new EvalError()
|
val eval = new EvalError()
|
||||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
|
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic").toMap
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
nWorkers = numWorkers, useExternalMemory = true)
|
nWorkers = numWorkers, useExternalMemory = true)
|
||||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
testSetDMatrix) < 0.1)
|
testSetDMatrix) < 0.1)
|
||||||
customSparkContext.stop()
|
|
||||||
// clean
|
// clean
|
||||||
val dir = new File(".")
|
cleanExternalCache("XGBoostSuite")
|
||||||
for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-0-dtrain_cache")) {
|
|
||||||
file.delete()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test with dense vectors containing missing value") {
|
test("test with dense vectors containing missing value") {
|
||||||
@ -106,10 +87,13 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
}
|
}
|
||||||
val trainingRDD = buildDenseRDD().repartition(4)
|
val trainingRDD = buildDenseRDD().repartition(4)
|
||||||
val testRDD = buildDenseRDD().repartition(4)
|
val testRDD = buildDenseRDD().repartition(4)
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic").toMap
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
|
||||||
|
useExternalMemory = true)
|
||||||
xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
|
xgBoostModel.predict(testRDD.map(_.features.toDense), missingValue = -0.1f).collect()
|
||||||
|
// clean
|
||||||
|
cleanExternalCache("XGBoostSuite")
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test consistency of prediction functions with RDD") {
|
test("test consistency of prediction functions with RDD") {
|
||||||
@ -120,11 +104,12 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
for (i <- testSet.indices) {
|
for (i <- testSet.indices) {
|
||||||
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
|
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
|
||||||
}
|
}
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic")
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||||
val predRDD = xgBoostModel.predict(testRDD)
|
val predRDD = xgBoostModel.predict(testRDD)
|
||||||
val predResult1 = predRDD.collect()(0)
|
val predResult1 = predRDD.collect()(0)
|
||||||
|
assert(testRDD.count() === predResult1.length)
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
|
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
|
||||||
for (i <- predResult1.indices; j <- predResult1(i).indices) {
|
for (i <- predResult1.indices; j <- predResult1(i).indices) {
|
||||||
@ -134,9 +119,9 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
|
|
||||||
test("test eval functions with RDD") {
|
test("test eval functions with RDD") {
|
||||||
val trainingRDD = buildTrainingRDD(sc).cache()
|
val trainingRDD = buildTrainingRDD(sc).cache()
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic")
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
|
||||||
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
|
xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
|
||||||
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
|
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
|
||||||
}
|
}
|
||||||
@ -150,7 +135,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
val testRDD = buildEmptyRDD()
|
val testRDD = buildEmptyRDD()
|
||||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic").toMap
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||||
println(xgBoostModel.predict(testRDD).collect().length === 0)
|
println(xgBoostModel.predict(testRDD).collect().length === 0)
|
||||||
@ -164,8 +149,8 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
||||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap
|
"objective" -> "binary:logistic")
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||||
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||||
testSetDMatrix)
|
testSetDMatrix)
|
||||||
@ -177,41 +162,40 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
|
|||||||
assert(loadedEvalResults == evalResults)
|
assert(loadedEvalResults == evalResults)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("nthread configuration must be equal to spark.task.cpus") {
|
test("test save and load of different types of models") {
|
||||||
sc.stop()
|
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||||
sc = null
|
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite").
|
val trainingRDD = buildTrainingRDD(sc)
|
||||||
set("spark.task.cpus", "4")
|
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
val customSparkContext = new SparkContext(sparkConf)
|
"objective" -> "reg:linear")
|
||||||
customSparkContext.setLogLevel("ERROR")
|
// validate regression model
|
||||||
// start another app
|
var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
nWorkers = numWorkers, useExternalMemory = false)
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
xgBoostModel.setFeaturesCol("feature_col")
|
||||||
"objective" -> "binary:logistic", "nthread" -> 6).toMap
|
xgBoostModel.setLabelCol("label_col")
|
||||||
intercept[IllegalArgumentException] {
|
xgBoostModel.setPredictionCol("prediction_col")
|
||||||
XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||||
}
|
var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||||
customSparkContext.stop()
|
assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
|
||||||
}
|
assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
|
||||||
|
assert(loadedXGBoostModel.getLabelCol == "label_col")
|
||||||
test("kryoSerializer test") {
|
assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
|
||||||
sc.stop()
|
// classification model
|
||||||
sc = null
|
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
val eval = new EvalError()
|
"objective" -> "binary:logistic")
|
||||||
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
|
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
nWorkers = numWorkers, useExternalMemory = false)
|
||||||
sparkConf.registerKryoClasses(Array(classOf[Booster]))
|
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
|
||||||
val customSparkContext = new SparkContext(sparkConf)
|
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
|
||||||
customSparkContext.setLogLevel("ERROR")
|
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||||
val trainingRDD = buildTrainingRDD(customSparkContext)
|
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||||
val testSet = loadLabelPoints(getClass.getResource("/agaricus.txt.test").getFile).iterator
|
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
|
||||||
import DataUtils._
|
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
|
||||||
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
|
"raw_col")
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
|
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
|
||||||
"objective" -> "binary:logistic").toMap
|
Array(0.5, 0.5).deep)
|
||||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||||
testSetDMatrix) < 0.1)
|
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||||
customSparkContext.stop()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
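The rewritten persistence test saves a trained model with saveModelAsHadoopFile and reloads it with XGBoost.loadModelFromHadoopFile, asserting that the concrete type (regression vs. classification) and the configured columns survive the round trip. A condensed sketch of that flow; it assumes, as in the suite, an implicit SparkContext in scope plus an existing trainingRDD, numWorkers and a writable modelPath:

// assumed from the surrounding suite: implicit sc: SparkContext,
// trainingRDD: RDD[LabeledPoint], numWorkers: Int, modelPath: String
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
  "objective" -> "binary:logistic")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
  nWorkers = numWorkers, useExternalMemory = false)
model.setPredictionCol("prediction_col")
model.saveModelAsHadoopFile(modelPath)

val reloaded = XGBoost.loadModelFromHadoopFile(modelPath)
// a binary:logistic objective should come back as a classification model
assert(reloaded.isInstanceOf[XGBoostClassificationModel])
assert(reloaded.getPredictionCol == "prediction_col")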
@ -38,6 +38,8 @@ trait EvalTrait extends IEvaluation {
|
|||||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||||
|
|
||||||
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
|
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
|
||||||
|
require(predicts.length == jdmat.getLabel.length, "predicts size and label size must match " +
|
||||||
|
s" predicts size: ${predicts.length}, label size: ${jdmat.getLabel.length}")
|
||||||
eval(predicts, new DMatrix(jdmat))
|
eval(predicts, new DMatrix(jdmat))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
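The final hunk adds a size check to EvalTrait's bridging eval before delegating to the user-supplied metric. For reference, a custom metric only needs to supply the metric name and the Array-based eval; below is a minimal error-rate sketch in the spirit of the EvalError used throughout these suites (the getMetric member and the getLabel behaviour are assumed from the existing trait, not shown in this diff):

import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.LogFactory

class SimpleErrorRate extends EvalTrait {
  private val logger = LogFactory.getLog(classOf[SimpleErrorRate])

  override def getMetric: String = "simple-error"

  override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
    val labels: Array[Float] =
      try dmat.getLabel
      catch { case ex: XGBoostError => logger.error(ex); return -1f }
    var error = 0
    for (i <- predicts.indices) {
      // margin output: a positive margin is treated as the positive class
      if ((labels(i) == 0.0f && predicts(i)(0) > 0) ||
          (labels(i) == 1.0f && predicts(i)(0) <= 0)) {
        error += 1
      }
    }
    error.toFloat / labels.length
  }
}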