[jvm-packages] enable predictLeaf/predictContrib/treeLimit in 0.8 (#3532)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* partial finish

* no test

* add test cases

* add test cases

* address comments

* add test for regressor

* fix typo
This commit is contained in:
Nan Zhu 2018-08-07 14:01:18 -07:00 committed by GitHub
parent 246ec92163
commit 1c08b3b2ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 349 additions and 32 deletions

View File

@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.params._ import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext import org.apache.spark.TaskContext
import org.apache.spark.ml.classification._ import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._ import org.apache.spark.ml.linalg._
@ -39,8 +39,11 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql._ import org.apache.spark.sql._
import org.json4s.DefaultFormats import org.json4s.DefaultFormats
import org.apache.spark.broadcast.Broadcast
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
with HasLeafPredictionCol with HasContribPredictionCol
class XGBoostClassifier ( class XGBoostClassifier (
override val uid: String, override val uid: String,
@ -235,6 +238,12 @@ class XGBoostClassificationModel private[ml](
this this
} }
/** Sets the column name for predictLeaf output; empty disables it. @group setParam */
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

/** Sets the column name for predictContrib output; empty disables it. @group setParam */
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

/** Limits the number of trees used in prediction; 0 (default) uses all trees. @group setParam */
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
/** /**
* Single instance prediction. * Single instance prediction.
* Note: The performance is not ideal, use it carefully! * Note: The performance is not ideal, use it carefully!
@ -287,22 +296,15 @@ class XGBoostClassificationModel private[ml](
null null
} }
} }
val dm = new DMatrix( val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)), XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo) cacheInfo)
try { try {
val rawPredictionItr = { val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator producePredictionItrs(bBooster, dm)
}
val probabilityItr = {
bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator
}
Rabit.shutdown() Rabit.shutdown()
rowItr1.zip(rawPredictionItr).zip(probabilityItr).map { produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
case ((originals: Row, rawPrediction: Row), probability: Row) => predContribItr)
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
}
} finally { } finally {
dm.delete() dm.delete()
} }
@ -313,7 +315,82 @@ class XGBoostClassificationModel private[ml](
bBooster.unpersist(blocking = false) bBooster.unpersist(blocking = false)
dataset.sparkSession.createDataFrame(rdd, schema) dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
}
/**
 * Zips the original rows with the prediction iterators into result rows.
 *
 * The optional leaf/contrib columns are appended only when the corresponding
 * column name is defined AND non-empty; this must stay in sync with
 * generateResultSchema. The previous 4-way if/else duplicated the zip logic
 * for every combination; here the optional columns are folded in one by one.
 */
private def produceResultIterator(
    originalRowItr: Iterator[Row],
    rawPredictionItr: Iterator[Row],
    probabilityItr: Iterator[Row],
    predLeafItr: Iterator[Row],
    predContribItr: Iterator[Row]): Iterator[Row] = {
  val leafRequested = isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty
  val contribRequested = isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty
  // Base result: original columns ++ rawPrediction ++ probability.
  val base = originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
    case ((originals: Row, rawPrediction: Row), probability: Row) =>
      originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq
  }
  val withLeaves =
    if (leafRequested) {
      base.zip(predLeafItr).map { case (fields, leaves: Row) => fields ++ leaves.toSeq }
    } else {
      base
    }
  val withContribs =
    if (contribRequested) {
      withLeaves.zip(predContribItr).map { case (fields, contribs: Row) => fields ++ contribs.toSeq }
    } else {
      withLeaves
    }
  withContribs.map(Row.fromSeq)
}
/**
 * Extends the fixed output schema with the optional leaf/contrib columns.
 *
 * A column is added only when its name is defined AND non-empty — the same
 * condition used when producing the result rows. Checking isDefined alone
 * (the previous behavior) added a phantom field named "" that no row
 * populated, corrupting the resulting DataFrame.
 */
private def generateResultSchema(fixedSchema: StructType): StructType = {
  var resultSchema = fixedSchema
  if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty) {
    resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
      ArrayType(FloatType, containsNull = false), nullable = false))
  }
  if (isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
    resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
      ArrayType(FloatType, containsNull = false), nullable = false))
  }
  resultSchema
}
/**
 * Runs the booster over the DMatrix and returns the four prediction
 * iterators: raw margin, probability, leaf indices, contributions.
 *
 * Leaf/contrib predictions are computed only when the corresponding column
 * name is defined AND non-empty (same guard as the result/schema builders);
 * the isDefined-only check wasted a full predictLeaf/predictContrib pass
 * when the column name was set to "".
 */
private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
    Array[Iterator[Row]] = {
  val rawPredictionItr = {
    broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
      map(Row(_)).iterator
  }
  val probabilityItr = {
    broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).
      map(Row(_)).iterator
  }
  val predLeafItr = {
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty) {
      broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
    } else {
      Iterator()
    }
  }
  val predContribItr = {
    if (isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
    } else {
      Iterator()
    }
  }
  Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
}
override def transform(dataset: Dataset[_]): DataFrame = { override def transform(dataset: Dataset[_]): DataFrame = {
@ -329,11 +406,11 @@ class XGBoostClassificationModel private[ml](
var outputData = transformInternal(dataset) var outputData = transformInternal(dataset)
var numColsOutput = 0 var numColsOutput = 0
val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) => val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
Vectors.dense(rawPrediction.map(_.toDouble).toArray) Vectors.dense(rawPrediction.map(_.toDouble).toArray)
} }
val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) => val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
if (numClasses == 2) { if (numClasses == 2) {
Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble)) Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
} else { } else {
@ -341,7 +418,7 @@ class XGBoostClassificationModel private[ml](
} }
} }
val predictUDF = udf { (probability: mutable.WrappedArray[Float]) => val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
// From XGBoost probability to MLlib prediction // From XGBoost probability to MLlib prediction
val probabilities = if (numClasses == 2) { val probabilities = if (numClasses == 2) {
Array(1 - probability(0), probability(0)).map(_.toDouble) Array(1 - probability(0), probability(0)).map(_.toDouble)

View File

@ -24,8 +24,8 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _} import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext import org.apache.spark.TaskContext
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.param.shared.HasWeightCol
@ -37,12 +37,13 @@ import org.apache.spark.sql._
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
import org.json4s.DefaultFormats import org.json4s.DefaultFormats
import scala.collection.mutable import scala.collection.mutable
import org.apache.spark.broadcast.Broadcast
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
with ParamMapFuncs with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol
class XGBoostRegressor ( class XGBoostRegressor (
override val uid: String, override val uid: String,
@ -231,6 +232,12 @@ class XGBoostRegressionModel private[ml] (
this this
} }
/** Sets the column name for predictLeaf output; empty disables it. @group setParam */
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

/** Sets the column name for predictContrib output; empty disables it. @group setParam */
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

/** Limits the number of trees used in prediction; 0 (default) uses all trees. @group setParam */
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
/** /**
* Single instance prediction. * Single instance prediction.
* Note: The performance is not ideal, use it carefully! * Note: The performance is not ideal, use it carefully!
@ -270,14 +277,10 @@ class XGBoostRegressionModel private[ml] (
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)), XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo) cacheInfo)
try { try {
val originalPredictionItr = { val Array(originalPredictionItr, predLeafItr, predContribItr) =
bBooster.value.predict(dm).map(Row(_)).iterator producePredictionItrs(bBooster, dm)
}
Rabit.shutdown() Rabit.shutdown()
rowItr1.zip(originalPredictionItr).map { produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
case (originals: Row, originalPrediction: Row) =>
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
}
} finally { } finally {
dm.delete() dm.delete()
} }
@ -285,10 +288,77 @@ class XGBoostRegressionModel private[ml] (
Iterator[Row]() Iterator[Row]()
} }
} }
bBooster.unpersist(blocking = false) bBooster.unpersist(blocking = false)
dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
}
/**
 * Zips the original rows with the prediction iterators into result rows.
 *
 * The optional leaf/contrib columns are appended only when the corresponding
 * column name is defined AND non-empty; this must stay in sync with
 * generateResultSchema. The previous 4-way if/else duplicated the zip logic
 * for every combination; here the optional columns are folded in one by one.
 */
private def produceResultIterator(
    originalRowItr: Iterator[Row],
    predictionItr: Iterator[Row],
    predLeafItr: Iterator[Row],
    predContribItr: Iterator[Row]): Iterator[Row] = {
  val leafRequested = isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty
  val contribRequested = isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty
  // Base result: original columns ++ prediction.
  val base = originalRowItr.zip(predictionItr).map {
    case (originals: Row, prediction: Row) =>
      originals.toSeq ++ prediction.toSeq
  }
  val withLeaves =
    if (leafRequested) {
      base.zip(predLeafItr).map { case (fields, leaves: Row) => fields ++ leaves.toSeq }
    } else {
      base
    }
  val withContribs =
    if (contribRequested) {
      withLeaves.zip(predContribItr).map { case (fields, contribs: Row) => fields ++ contribs.toSeq }
    } else {
      withLeaves
    }
  withContribs.map(Row.fromSeq)
}
/**
 * Extends the fixed output schema with the optional leaf/contrib columns.
 *
 * A column is added only when its name is defined AND non-empty — the same
 * condition used when producing the result rows. Checking isDefined alone
 * (the previous behavior) added a phantom field named "" that no row
 * populated, corrupting the resulting DataFrame.
 */
private def generateResultSchema(fixedSchema: StructType): StructType = {
  var resultSchema = fixedSchema
  if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty) {
    resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
      ArrayType(FloatType, containsNull = false), nullable = false))
  }
  if (isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
    resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
      ArrayType(FloatType, containsNull = false), nullable = false))
  }
  resultSchema
}
/**
 * Runs the booster over the DMatrix and returns the three prediction
 * iterators: original prediction, leaf indices, contributions.
 *
 * Leaf/contrib predictions are computed only when the corresponding column
 * name is defined AND non-empty (same guard as the result/schema builders);
 * the isDefined-only check wasted a full predictLeaf/predictContrib pass
 * when the column name was set to "".
 */
private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
    Array[Iterator[Row]] = {
  val originalPredictionItr = {
    broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
  }
  val predLeafItr = {
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty) {
      broadcastBooster.value.predictLeaf(dm, $(treeLimit)).
        map(Row(_)).iterator
    } else {
      Iterator()
    }
  }
  val predContribItr = {
    if (isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      broadcastBooster.value.predictContrib(dm, $(treeLimit)).
        map(Row(_)).iterator
    } else {
      Iterator()
    }
  }
  Array(originalPredictionItr, predLeafItr, predContribItr)
}
override def transform(dataset: Dataset[_]): DataFrame = { override def transform(dataset: Dataset[_]): DataFrame = {

View File

@ -237,13 +237,18 @@ private[spark] trait BoosterParams extends Params {
final def getLambdaBias: Double = $(lambdaBias) final def getLambdaBias: Double = $(lambdaBias)
/** Number of trees used in prediction; 0 (the default) means use all trees. @group param */
final val treeLimit = new IntParam(this, name = "treeLimit",
  doc = "number of trees used in the prediction; defaults to 0 (use all trees).")

// BUG FIX: the getter previously read $(lambdaBias) and declared Double —
// a copy-paste error from the getLambdaBias getter above.
/** @group getParam */
final def getTreeLimit: Int = $(treeLimit)
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0, minChildWeight -> 1, maxDeltaStep -> 0,
growPolicy -> "depthwise", maxBins -> 16, growPolicy -> "depthwise", maxBins -> 16,
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1, subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03, lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree", scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0) rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
} }
private[spark] object BoosterParams { private[spark] object BoosterParams {

View File

@ -158,6 +158,30 @@ private[spark] trait GeneralParams extends Params {
) )
} }
/**
 * Mixin adding the `leafPredictionCol` parameter. When set to a non-empty
 * name, transform() emits a column of that name holding predictLeaf output.
 */
trait HasLeafPredictionCol extends Params {

  /**
   * Param for leaf prediction column name.
   * @group param
   */
  final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
    "name of the predictLeaf results")

  /** @group getParam */
  final def getLeafPredictionCol: String = $(leafPredictionCol)
}
/**
 * Mixin adding the `contribPredictionCol` parameter. When set to a non-empty
 * name, transform() emits a column of that name holding predictContrib output.
 */
trait HasContribPredictionCol extends Params {

  /**
   * Param for contribution prediction column name.
   * @group param
   */
  final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
    "name of the predictContrib results")

  /** @group getParam */
  final def getContribPredictionCol: String = $(contribPredictionCol)
}
trait HasBaseMarginCol extends Params { trait HasBaseMarginCol extends Params {
/** /**

View File

@ -211,4 +211,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(testObjectiveHistory.length === 5) assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory) assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
} }
test("test predictionLeaf") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
    "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Classification.train)
  val testDF = buildDataFrame(Classification.test)
  val expectedRowCount = testDF.count()
  // Fit, then request leaf predictions via the dedicated output column.
  val model = new XGBoostClassifier(paramMap).fit(trainingDF)
    .setLeafPredictionCol("predictLeaf")
  val transformed = model.transform(testDF)
  assert(transformed.count == expectedRowCount)
  assert(transformed.columns.contains("predictLeaf"))
}
test("test predictionLeaf with empty column name") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
    "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Classification.train)
  val testDF = buildDataFrame(Classification.test)
  // An empty column name must disable leaf prediction output entirely.
  val model = new XGBoostClassifier(paramMap).fit(trainingDF)
    .setLeafPredictionCol("")
  val transformed = model.transform(testDF)
  assert(!transformed.columns.contains("predictLeaf"))
}
test("test predictionContrib") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
    "num_round" -> 5, "num_workers" -> numWorkers)
  val training = buildDataFrame(Classification.train)
  val test = buildDataFrame(Classification.test)
  val groundTruth = test.count()
  val xgb = new XGBoostClassifier(paramMap)
  val model = xgb.fit(training)
  model.setContribPredictionCol("predictContrib")
  // Reuse the DataFrame counted for the ground truth instead of rebuilding it
  // (the original called buildDataFrame a second time — wasteful and
  // inconsistent with the sibling tests).
  val resultDF = model.transform(test)
  assert(resultDF.count == groundTruth)
  assert(resultDF.columns.contains("predictContrib"))
}
test("test predictionContrib with empty column name") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
    "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Classification.train)
  val testDF = buildDataFrame(Classification.test)
  // An empty column name must disable contribution output entirely.
  val model = new XGBoostClassifier(paramMap).fit(trainingDF)
    .setContribPredictionCol("")
  val transformed = model.transform(testDF)
  assert(!transformed.columns.contains("predictContrib"))
}
test("test predictionLeaf and predictionContrib") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
    "num_round" -> 5, "num_workers" -> numWorkers)
  val training = buildDataFrame(Classification.train)
  val test = buildDataFrame(Classification.test)
  val groundTruth = test.count()
  val xgb = new XGBoostClassifier(paramMap)
  val model = xgb.fit(training)
  model.setLeafPredictionCol("predictLeaf")
  model.setContribPredictionCol("predictContrib")
  // Reuse the DataFrame counted for the ground truth instead of rebuilding it
  // (the original called buildDataFrame a second time — wasteful and
  // inconsistent with the sibling tests).
  val resultDF = model.transform(test)
  assert(resultDF.count == groundTruth)
  assert(resultDF.columns.contains("predictLeaf"))
  assert(resultDF.columns.contains("predictContrib"))
}
} }

View File

@ -120,4 +120,72 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
val first = prediction.head.getAs[Double]("prediction") val first = prediction.head.getAs[Double]("prediction")
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f)) prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
} }
test("test predictionLeaf") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Regression.train)
  val evalDF = buildDataFrame(Regression.test)
  val expectedRowCount = evalDF.count()
  // Fit, then request leaf predictions via the dedicated output column.
  val model = new XGBoostRegressor(paramMap).fit(trainingDF)
    .setLeafPredictionCol("predictLeaf")
  val transformed = model.transform(evalDF)
  assert(transformed.count === expectedRowCount)
  assert(transformed.columns.contains("predictLeaf"))
}
test("test predictionLeaf with empty column name") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Regression.train)
  val evalDF = buildDataFrame(Regression.test)
  // An empty column name must disable leaf prediction output entirely.
  val model = new XGBoostRegressor(paramMap).fit(trainingDF)
    .setLeafPredictionCol("")
  val transformed = model.transform(evalDF)
  assert(!transformed.columns.contains("predictLeaf"))
}
test("test predictionContrib") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Regression.train)
  val evalDF = buildDataFrame(Regression.test)
  val expectedRowCount = evalDF.count()
  // Fit, then request feature contributions via the dedicated output column.
  val model = new XGBoostRegressor(paramMap).fit(trainingDF)
    .setContribPredictionCol("predictContrib")
  val transformed = model.transform(evalDF)
  assert(transformed.count === expectedRowCount)
  assert(transformed.columns.contains("predictContrib"))
}
test("test predictionContrib with empty column name") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Regression.train)
  val evalDF = buildDataFrame(Regression.test)
  // An empty column name must disable contribution output entirely.
  val model = new XGBoostRegressor(paramMap).fit(trainingDF)
    .setContribPredictionCol("")
  val transformed = model.transform(evalDF)
  assert(!transformed.columns.contains("predictContrib"))
}
test("test predictionLeaf and predictionContrib") {
  val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
    "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
  val trainingDF = buildDataFrame(Regression.train)
  val evalDF = buildDataFrame(Regression.test)
  val expectedRowCount = evalDF.count()
  // Request both leaf and contribution outputs on the same model.
  val model = new XGBoostRegressor(paramMap).fit(trainingDF)
    .setLeafPredictionCol("predictLeaf")
    .setContribPredictionCol("predictContrib")
  val transformed = model.transform(evalDF)
  assert(transformed.count === expectedRowCount)
  assert(transformed.columns.contains("predictLeaf"))
  assert(transformed.columns.contains("predictContrib"))
}
} }

View File

@ -125,8 +125,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
* @return predict result * @return predict result
*/ */
@throws(classOf[XGBoostError]) @throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0) def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0):
: Array[Array[Float]] = { Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit) booster.predict(data.jDMatrix, outPutMargin, treeLimit)
} }
@ -139,7 +139,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
@throws(classOf[XGBoostError]) @throws(classOf[XGBoostError])
def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = { def predictLeaf(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
booster.predictLeaf(data.jDMatrix, treeLimit) booster.predictLeaf(data.jDMatrix, treeLimit)
} }