[jvm-packages] enable predictLeaf/predictContrib/treeLimit in 0.8 (#3532)

* add back train method but mark as deprecated
* fix scalastyle error
* partial finish
* no test
* add test cases
* address comments
* add test for regressor
* fix typo
This commit is contained in:
parent 246ec92163
commit 1c08b3b2ea
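The change surfaces XGBoost's predictLeaf/predictContrib/treeLimit through the Spark layer. A minimal usage sketch of the new API, assuming `training` and `test` are DataFrames with `features`/`label` columns (as produced by `buildDataFrame` in the test suites below); column names are arbitrary:

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

val xgb = new XGBoostClassifier(Map(
  "eta" -> "1", "max_depth" -> "6", "objective" -> "binary:logistic",
  "num_round" -> 5, "num_workers" -> 2))
val model = xgb.fit(training)
model.setLeafPredictionCol("predictLeaf")        // appends an ArrayType(FloatType) column of leaf indices
model.setContribPredictionCol("predictContrib")  // appends an ArrayType(FloatType) column of feature contributions
model.setTreeLimit(0)                            // 0 (the default) means use all trees
val result = model.transform(test)
```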
@@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.hadoop.fs.Path

import org.apache.spark.TaskContext
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
@@ -39,8 +39,11 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.json4s.DefaultFormats

import org.apache.spark.broadcast.Broadcast

private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
  with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
  with HasLeafPredictionCol with HasContribPredictionCol

class XGBoostClassifier (
    override val uid: String,
@@ -235,6 +238,12 @@ class XGBoostClassificationModel private[ml](
    this
  }

  def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

  def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

  def setTreeLimit(value: Int): this.type = set(treeLimit, value)

  /**
   * Single instance prediction.
   * Note: The performance is not ideal, use it carefully!
@@ -287,22 +296,15 @@ class XGBoostClassificationModel private[ml](
          null
        }
      }

      val dm = new DMatrix(
        XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
        cacheInfo)
      try {
        val rawPredictionItr = {
          bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator
        }
        val probabilityItr = {
          bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator
        }
        val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
          producePredictionItrs(bBooster, dm)
        Rabit.shutdown()
        rowItr1.zip(rawPredictionItr).zip(probabilityItr).map {
          case ((originals: Row, rawPrediction: Row), probability: Row) =>
            Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
        }
        produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
          predContribItr)
      } finally {
        dm.delete()
      }
@@ -313,7 +315,82 @@ class XGBoostClassificationModel private[ml](

    bBooster.unpersist(blocking = false)

    dataset.sparkSession.createDataFrame(rdd, schema)
    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
  }

  private def produceResultIterator(
      originalRowItr: Iterator[Row],
      rawPredictionItr: Iterator[Row],
      probabilityItr: Iterator[Row],
      predLeafItr: Iterator[Row],
      predContribItr: Iterator[Row]): Iterator[Row] = {
    // the following implementation is to be improved
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
        map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
          contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
            contribs.toSeq)
        }
    } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
        map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
        }
    } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
        map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
        }
    } else {
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
        case ((originals: Row, rawPrediction: Row), probability: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
      }
    }
  }
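One subtlety the four branches guard against: `Iterator.zip` truncates to the shorter side, so unconditionally zipping the row iterator with an empty leaf or contrib iterator would silently drop every row. A two-line illustration of that stdlib Scala semantics, nothing project-specific:

```scala
val rows = Iterator(1, 2, 3)
println(rows.zip(Iterator.empty).size) // 0 -- zip stops at the shorter iterator
```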

  private def generateResultSchema(fixedSchema: StructType): StructType = {
    var resultSchema = fixedSchema
    if (isDefined(leafPredictionCol)) {
      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
        ArrayType(FloatType, containsNull = false), nullable = false))
    }
    if (isDefined(contribPredictionCol)) {
      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
        ArrayType(FloatType, containsNull = false), nullable = false))
    }
    resultSchema
  }

  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
      Array[Iterator[Row]] = {
    val rawPredictionItr = {
      broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
        map(Row(_)).iterator
    }
    val probabilityItr = {
      broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).
        map(Row(_)).iterator
    }
    val predLeafItr = {
      if (isDefined(leafPredictionCol)) {
        broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    val predContribItr = {
      if (isDefined(contribPredictionCol)) {
        broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
@@ -329,11 +406,11 @@ class XGBoostClassificationModel private[ml](
    var outputData = transformInternal(dataset)
    var numColsOutput = 0

    val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) =>
    val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
      Vectors.dense(rawPrediction.map(_.toDouble).toArray)
    }

    val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) =>
    val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
      if (numClasses == 2) {
        Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
      } else {
@@ -341,7 +418,7 @@ class XGBoostClassificationModel private[ml](
      }
    }

    val predictUDF = udf { (probability: mutable.WrappedArray[Float]) =>
    val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
      // From XGBoost probability to MLlib prediction
      val probabilities = if (numClasses == 2) {
        Array(1 - probability(0), probability(0)).map(_.toDouble)
@@ -24,8 +24,8 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}

import org.apache.hadoop.fs.Path

import org.apache.spark.TaskContext
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param.shared.HasWeightCol
@@ -37,12 +37,13 @@ import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.json4s.DefaultFormats

import scala.collection.mutable

import org.apache.spark.broadcast.Broadcast

private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
  with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
  with ParamMapFuncs
  with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol

class XGBoostRegressor (
    override val uid: String,
@@ -231,6 +232,12 @@ class XGBoostRegressionModel private[ml] (
    this
  }

  def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

  def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

  def setTreeLimit(value: Int): this.type = set(treeLimit, value)

  /**
   * Single instance prediction.
   * Note: The performance is not ideal, use it carefully!
@@ -270,14 +277,10 @@ class XGBoostRegressionModel private[ml] (
        XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
        cacheInfo)
      try {
        val originalPredictionItr = {
          bBooster.value.predict(dm).map(Row(_)).iterator
        }
        val Array(originalPredictionItr, predLeafItr, predContribItr) =
          producePredictionItrs(bBooster, dm)
        Rabit.shutdown()
        rowItr1.zip(originalPredictionItr).map {
          case (originals: Row, originalPrediction: Row) =>
            Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
        }
        produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
      } finally {
        dm.delete()
      }
@@ -285,10 +288,77 @@ class XGBoostRegressionModel private[ml] (
        Iterator[Row]()
      }
    }

    bBooster.unpersist(blocking = false)
    dataset.sparkSession.createDataFrame(rdd, schema)
    dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
  }

  private def produceResultIterator(
      originalRowItr: Iterator[Row],
      predictionItr: Iterator[Row],
      predLeafItr: Iterator[Row],
      predContribItr: Iterator[Row]): Iterator[Row] = {
    // the following implementation is to be improved
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      originalRowItr.zip(predictionItr).zip(predLeafItr).zip(predContribItr).
        map { case (((originals: Row, prediction: Row), leaves: Row), contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq ++ contribs.toSeq)
        }
    } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
      originalRowItr.zip(predictionItr).zip(predLeafItr).
        map { case ((originals: Row, prediction: Row), leaves: Row) =>
          Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq)
        }
    } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      originalRowItr.zip(predictionItr).zip(predContribItr).
        map { case ((originals: Row, prediction: Row), contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ contribs.toSeq)
        }
    } else {
      originalRowItr.zip(predictionItr).map {
        case (originals: Row, originalPrediction: Row) =>
          Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
      }
    }
  }

  private def generateResultSchema(fixedSchema: StructType): StructType = {
    var resultSchema = fixedSchema
    if (isDefined(leafPredictionCol)) {
      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
        ArrayType(FloatType, containsNull = false), nullable = false))
    }
    if (isDefined(contribPredictionCol)) {
      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
        ArrayType(FloatType, containsNull = false), nullable = false))
    }
    resultSchema
  }

  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
      Array[Iterator[Row]] = {
    val originalPredictionItr = {
      broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
    }
    val predLeafItr = {
      if (isDefined(leafPredictionCol)) {
        broadcastBooster.value.predictLeaf(dm, $(treeLimit)).
          map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    val predContribItr = {
      if (isDefined(contribPredictionCol)) {
        broadcastBooster.value.predictContrib(dm, $(treeLimit)).
          map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    Array(originalPredictionItr, predLeafItr, predContribItr)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
@@ -237,13 +237,18 @@ private[spark] trait BoosterParams extends Params {

  final def getLambdaBias: Double = $(lambdaBias)

  final val treeLimit = new IntParam(this, name = "treeLimit",
    doc = "number of trees used in the prediction; defaults to 0 (use all trees).")

  final def getTreeLimit: Int = $(treeLimit)

  setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
    minChildWeight -> 1, maxDeltaStep -> 0,
    growPolicy -> "depthwise", maxBins -> 16,
    subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
    lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}

private[spark] object BoosterParams {
@@ -158,6 +158,30 @@ private[spark] trait GeneralParams extends Params {
  )
}

trait HasLeafPredictionCol extends Params {
  /**
   * Param for leaf prediction column name.
   * @group param
   */
  final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
    "name of the predictLeaf results")

  /** @group getParam */
  final def getLeafPredictionCol: String = $(leafPredictionCol)
}

trait HasContribPredictionCol extends Params {
  /**
   * Param for contribution prediction column name.
   * @group param
   */
  final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
    "name of the predictContrib results")

  /** @group getParam */
  final def getContribPredictionCol: String = $(contribPredictionCol)
}

trait HasBaseMarginCol extends Params {

  /**
@@ -211,4 +211,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
    assert(testObjectiveHistory.length === 5)
    assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
  }

  test("test predictionLeaf") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
      "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)
    val test = buildDataFrame(Classification.test)
    val groundTruth = test.count()
    val xgb = new XGBoostClassifier(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("predictLeaf")
    val resultDF = model.transform(test)
    assert(resultDF.count == groundTruth)
    assert(resultDF.columns.contains("predictLeaf"))
  }

  test("test predictionLeaf with empty column name") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
      "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)
    val test = buildDataFrame(Classification.test)
    val xgb = new XGBoostClassifier(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("")
    val resultDF = model.transform(test)
    assert(!resultDF.columns.contains("predictLeaf"))
  }

  test("test predictionContrib") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
      "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)
    val test = buildDataFrame(Classification.test)
    val groundTruth = test.count()
    val xgb = new XGBoostClassifier(paramMap)
    val model = xgb.fit(training)
    model.setContribPredictionCol("predictContrib")
    val resultDF = model.transform(buildDataFrame(Classification.test))
    assert(resultDF.count == groundTruth)
    assert(resultDF.columns.contains("predictContrib"))
  }

  test("test predictionContrib with empty column name") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
      "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)
    val test = buildDataFrame(Classification.test)
    val xgb = new XGBoostClassifier(paramMap)
    val model = xgb.fit(training)
    model.setContribPredictionCol("")
    val resultDF = model.transform(test)
    assert(!resultDF.columns.contains("predictContrib"))
  }

  test("test predictionLeaf and predictionContrib") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
      "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Classification.train)
    val test = buildDataFrame(Classification.test)
    val groundTruth = test.count()
    val xgb = new XGBoostClassifier(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("predictLeaf")
    model.setContribPredictionCol("predictContrib")
    val resultDF = model.transform(buildDataFrame(Classification.test))
    assert(resultDF.count == groundTruth)
    assert(resultDF.columns.contains("predictLeaf"))
    assert(resultDF.columns.contains("predictContrib"))
  }
}
@@ -120,4 +120,72 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
    val first = prediction.head.getAs[Double]("prediction")
    prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
  }

  test("test predictionLeaf") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val groundTruth = testDF.count()
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("predictLeaf")
    val resultDF = model.transform(testDF)
    assert(resultDF.count === groundTruth)
    assert(resultDF.columns.contains("predictLeaf"))
  }

  test("test predictionLeaf with empty column name") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("")
    val resultDF = model.transform(testDF)
    assert(!resultDF.columns.contains("predictLeaf"))
  }

  test("test predictionContrib") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val groundTruth = testDF.count()
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setContribPredictionCol("predictContrib")
    val resultDF = model.transform(testDF)
    assert(resultDF.count === groundTruth)
    assert(resultDF.columns.contains("predictContrib"))
  }

  test("test predictionContrib with empty column name") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setContribPredictionCol("")
    val resultDF = model.transform(testDF)
    assert(!resultDF.columns.contains("predictContrib"))
  }

  test("test predictionLeaf and predictionContrib") {
    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
    val training = buildDataFrame(Regression.train)
    val testDF = buildDataFrame(Regression.test)
    val groundTruth = testDF.count()
    val xgb = new XGBoostRegressor(paramMap)
    val model = xgb.fit(training)
    model.setLeafPredictionCol("predictLeaf")
    model.setContribPredictionCol("predictContrib")
    val resultDF = model.transform(testDF)
    assert(resultDF.count === groundTruth)
    assert(resultDF.columns.contains("predictLeaf"))
    assert(resultDF.columns.contains("predictContrib"))
  }
}
@@ -125,8 +125,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
   * @return predict result
   */
  @throws(classOf[XGBoostError])
  def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
    : Array[Array[Float]] = {
  def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0):
      Array[Array[Float]] = {
    booster.predict(data.jDMatrix, outPutMargin, treeLimit)
  }

@@ -139,7 +139,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
   * @throws XGBoostError native error
   */
  @throws(classOf[XGBoostError])
  def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
  def predictLeaf(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
    booster.predictLeaf(data.jDMatrix, treeLimit)
  }
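At the xgboost4j-scala level the same three capabilities are plain `Booster` methods, so they can also be exercised without Spark. A minimal sketch, assuming a hypothetical `train.txt` in LibSVM format (the `treeLimit` arguments follow the signatures shown above):

```scala
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost}

val dtrain = new DMatrix("train.txt")  // hypothetical LibSVM training file
val booster: Booster = XGBoost.train(dtrain,
  Map("eta" -> 0.3, "objective" -> "binary:logistic"), round = 5)

val margins = booster.predict(dtrain, outPutMargin = true, treeLimit = 3) // first 3 trees only
val leaves = booster.predictLeaf(dtrain)     // per-row leaf indices, one per tree
val contribs = booster.predictContrib(dtrain) // per-row feature contributions plus bias
```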