[jvm-packages] enable predictLeaf/predictContrib/treeLimit in 0.8 (#3532)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * partial finish * no test * add test cases * add test cases * address comments * add test for regressor * fix typo
This commit is contained in:
parent
246ec92163
commit
1c08b3b2ea
@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
|||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params._
|
import ml.dmlc.xgboost4j.scala.spark.params._
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
|
|
||||||
import org.apache.hadoop.fs.Path
|
import org.apache.hadoop.fs.Path
|
||||||
|
|
||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.ml.classification._
|
import org.apache.spark.ml.classification._
|
||||||
import org.apache.spark.ml.linalg._
|
import org.apache.spark.ml.linalg._
|
||||||
@ -39,8 +39,11 @@ import org.apache.spark.sql.types._
|
|||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.json4s.DefaultFormats
|
import org.json4s.DefaultFormats
|
||||||
|
|
||||||
|
import org.apache.spark.broadcast.Broadcast
|
||||||
|
|
||||||
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
|
private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams
|
||||||
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
|
with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs
|
||||||
|
with HasLeafPredictionCol with HasContribPredictionCol
|
||||||
|
|
||||||
class XGBoostClassifier (
|
class XGBoostClassifier (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
@ -235,6 +238,12 @@ class XGBoostClassificationModel private[ml](
|
|||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
|
||||||
|
|
||||||
|
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
|
||||||
|
|
||||||
|
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Single instance prediction.
|
* Single instance prediction.
|
||||||
* Note: The performance is not ideal, use it carefully!
|
* Note: The performance is not ideal, use it carefully!
|
||||||
@ -287,22 +296,15 @@ class XGBoostClassificationModel private[ml](
|
|||||||
null
|
null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val dm = new DMatrix(
|
val dm = new DMatrix(
|
||||||
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
||||||
cacheInfo)
|
cacheInfo)
|
||||||
try {
|
try {
|
||||||
val rawPredictionItr = {
|
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
|
||||||
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator
|
producePredictionItrs(bBooster, dm)
|
||||||
}
|
|
||||||
val probabilityItr = {
|
|
||||||
bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator
|
|
||||||
}
|
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
rowItr1.zip(rawPredictionItr).zip(probabilityItr).map {
|
produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr,
|
||||||
case ((originals: Row, rawPrediction: Row), probability: Row) =>
|
predContribItr)
|
||||||
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
|
|
||||||
}
|
|
||||||
} finally {
|
} finally {
|
||||||
dm.delete()
|
dm.delete()
|
||||||
}
|
}
|
||||||
@ -313,7 +315,82 @@ class XGBoostClassificationModel private[ml](
|
|||||||
|
|
||||||
bBooster.unpersist(blocking = false)
|
bBooster.unpersist(blocking = false)
|
||||||
|
|
||||||
dataset.sparkSession.createDataFrame(rdd, schema)
|
dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def produceResultIterator(
|
||||||
|
originalRowItr: Iterator[Row],
|
||||||
|
rawPredictionItr: Iterator[Row],
|
||||||
|
probabilityItr: Iterator[Row],
|
||||||
|
predLeafItr: Iterator[Row],
|
||||||
|
predContribItr: Iterator[Row]): Iterator[Row] = {
|
||||||
|
// the following implementation is to be improved
|
||||||
|
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
|
||||||
|
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
|
||||||
|
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
|
||||||
|
map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
|
||||||
|
contribs: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
|
||||||
|
contribs.toSeq)
|
||||||
|
}
|
||||||
|
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
|
||||||
|
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
|
||||||
|
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
|
||||||
|
map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
|
||||||
|
}
|
||||||
|
} else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
|
||||||
|
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
|
||||||
|
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
|
||||||
|
map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
|
||||||
|
case ((originals: Row, rawPrediction: Row), probability: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Extends the fixed output schema with the optional leaf / contribution columns.
 *
 * A column is appended only when its param is set to a non-empty name. This is the same
 * condition produceResultIterator uses when it appends the matching values to each row;
 * checking only isDefined here would declare a field (with an empty name) that the rows
 * never carry, which breaks the DataFrame once it is materialised.
 *
 * @param fixedSchema schema of the rows before the optional columns are appended
 * @return fixedSchema followed by the leaf column and then the contribution column,
 *         each present only if enabled
 */
private def generateResultSchema(fixedSchema: StructType): StructType = {
  val floatArrayType = ArrayType(FloatType, containsNull = false)
  // order matters: produceResultIterator appends leaves before contributions
  Seq(leafPredictionCol, contribPredictionCol)
    .filter(col => isDefined(col) && $(col).nonEmpty)
    .foldLeft(fixedSchema) { (schema, col) =>
      schema.add(StructField(name = $(col), dataType = floatArrayType, nullable = false))
    }
}
|
||||||
|
|
||||||
|
private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
|
||||||
|
Array[Iterator[Row]] = {
|
||||||
|
val rawPredictionItr = {
|
||||||
|
broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
|
||||||
|
map(Row(_)).iterator
|
||||||
|
}
|
||||||
|
val probabilityItr = {
|
||||||
|
broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).
|
||||||
|
map(Row(_)).iterator
|
||||||
|
}
|
||||||
|
val predLeafItr = {
|
||||||
|
if (isDefined(leafPredictionCol)) {
|
||||||
|
broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
|
||||||
|
} else {
|
||||||
|
Iterator()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val predContribItr = {
|
||||||
|
if (isDefined(contribPredictionCol)) {
|
||||||
|
broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
|
||||||
|
} else {
|
||||||
|
Iterator()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
@ -329,11 +406,11 @@ class XGBoostClassificationModel private[ml](
|
|||||||
var outputData = transformInternal(dataset)
|
var outputData = transformInternal(dataset)
|
||||||
var numColsOutput = 0
|
var numColsOutput = 0
|
||||||
|
|
||||||
val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) =>
|
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
|
||||||
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
|
Vectors.dense(rawPrediction.map(_.toDouble).toArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) =>
|
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
if (numClasses == 2) {
|
if (numClasses == 2) {
|
||||||
Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
|
Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble))
|
||||||
} else {
|
} else {
|
||||||
@ -341,7 +418,7 @@ class XGBoostClassificationModel private[ml](
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val predictUDF = udf { (probability: mutable.WrappedArray[Float]) =>
|
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
|
||||||
// From XGBoost probability to MLlib prediction
|
// From XGBoost probability to MLlib prediction
|
||||||
val probabilities = if (numClasses == 2) {
|
val probabilities = if (numClasses == 2) {
|
||||||
Array(1 - probability(0), probability(0)).map(_.toDouble)
|
Array(1 - probability(0), probability(0)).map(_.toDouble)
|
||||||
|
|||||||
@ -24,8 +24,8 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
|||||||
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||||
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
|
||||||
|
|
||||||
import org.apache.hadoop.fs.Path
|
import org.apache.hadoop.fs.Path
|
||||||
|
|
||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
|
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
|
||||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||||
@ -37,12 +37,13 @@ import org.apache.spark.sql._
|
|||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.json4s.DefaultFormats
|
import org.json4s.DefaultFormats
|
||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import org.apache.spark.broadcast.Broadcast
|
||||||
|
|
||||||
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
|
private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams
|
||||||
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
|
with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol
|
||||||
with ParamMapFuncs
|
with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol
|
||||||
|
|
||||||
class XGBoostRegressor (
|
class XGBoostRegressor (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
@ -231,6 +232,12 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
|
||||||
|
|
||||||
|
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
|
||||||
|
|
||||||
|
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Single instance prediction.
|
* Single instance prediction.
|
||||||
* Note: The performance is not ideal, use it carefully!
|
* Note: The performance is not ideal, use it carefully!
|
||||||
@ -270,14 +277,10 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
||||||
cacheInfo)
|
cacheInfo)
|
||||||
try {
|
try {
|
||||||
val originalPredictionItr = {
|
val Array(originalPredictionItr, predLeafItr, predContribItr) =
|
||||||
bBooster.value.predict(dm).map(Row(_)).iterator
|
producePredictionItrs(bBooster, dm)
|
||||||
}
|
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
rowItr1.zip(originalPredictionItr).map {
|
produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr)
|
||||||
case (originals: Row, originalPrediction: Row) =>
|
|
||||||
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
|
|
||||||
}
|
|
||||||
} finally {
|
} finally {
|
||||||
dm.delete()
|
dm.delete()
|
||||||
}
|
}
|
||||||
@ -285,10 +288,77 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
Iterator[Row]()
|
Iterator[Row]()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bBooster.unpersist(blocking = false)
|
bBooster.unpersist(blocking = false)
|
||||||
|
dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema))
|
||||||
|
}
|
||||||
|
|
||||||
dataset.sparkSession.createDataFrame(rdd, schema)
|
private def produceResultIterator(
|
||||||
|
originalRowItr: Iterator[Row],
|
||||||
|
predictionItr: Iterator[Row],
|
||||||
|
predLeafItr: Iterator[Row],
|
||||||
|
predContribItr: Iterator[Row]): Iterator[Row] = {
|
||||||
|
// the following implementation is to be improved
|
||||||
|
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
|
||||||
|
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
|
||||||
|
originalRowItr.zip(predictionItr).zip(predLeafItr).zip(predContribItr).
|
||||||
|
map { case (((originals: Row, prediction: Row), leaves: Row), contribs: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq ++ contribs.toSeq)
|
||||||
|
}
|
||||||
|
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
|
||||||
|
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
|
||||||
|
originalRowItr.zip(predictionItr).zip(predLeafItr).
|
||||||
|
map { case ((originals: Row, prediction: Row), leaves: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq)
|
||||||
|
}
|
||||||
|
} else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
|
||||||
|
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
|
||||||
|
originalRowItr.zip(predictionItr).zip(predContribItr).
|
||||||
|
map { case ((originals: Row, prediction: Row), contribs: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ contribs.toSeq)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
originalRowItr.zip(predictionItr).map {
|
||||||
|
case (originals: Row, originalPrediction: Row) =>
|
||||||
|
Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Extends the fixed output schema with the optional leaf / contribution columns.
 *
 * A column is appended only when its param is set to a non-empty name. This is the same
 * condition produceResultIterator uses when it appends the matching values to each row;
 * checking only isDefined here would declare a field (with an empty name) that the rows
 * never carry, which breaks the DataFrame once it is materialised.
 *
 * @param fixedSchema schema of the rows before the optional columns are appended
 * @return fixedSchema followed by the leaf column and then the contribution column,
 *         each present only if enabled
 */
private def generateResultSchema(fixedSchema: StructType): StructType = {
  val floatArrayType = ArrayType(FloatType, containsNull = false)
  // order matters: produceResultIterator appends leaves before contributions
  Seq(leafPredictionCol, contribPredictionCol)
    .filter(col => isDefined(col) && $(col).nonEmpty)
    .foldLeft(fixedSchema) { (schema, col) =>
      schema.add(StructField(name = $(col), dataType = floatArrayType, nullable = false))
    }
}
|
||||||
|
|
||||||
|
private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
|
||||||
|
Array[Iterator[Row]] = {
|
||||||
|
val originalPredictionItr = {
|
||||||
|
broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
|
||||||
|
}
|
||||||
|
val predLeafItr = {
|
||||||
|
if (isDefined(leafPredictionCol)) {
|
||||||
|
broadcastBooster.value.predictLeaf(dm, $(treeLimit)).
|
||||||
|
map(Row(_)).iterator
|
||||||
|
} else {
|
||||||
|
Iterator()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val predContribItr = {
|
||||||
|
if (isDefined(contribPredictionCol)) {
|
||||||
|
broadcastBooster.value.predictContrib(dm, $(treeLimit)).
|
||||||
|
map(Row(_)).iterator
|
||||||
|
} else {
|
||||||
|
Iterator()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Array(originalPredictionItr, predLeafItr, predContribItr)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
|
|||||||
@ -237,13 +237,18 @@ private[spark] trait BoosterParams extends Params {
|
|||||||
|
|
||||||
final def getLambdaBias: Double = $(lambdaBias)
|
final def getLambdaBias: Double = $(lambdaBias)
|
||||||
|
|
||||||
|
final val treeLimit = new IntParam(this, name = "treeLimit",
|
||||||
|
doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
|
||||||
|
|
||||||
|
/**
 * Number of trees used in the prediction; 0 means all trees are used.
 *
 * NOTE(review): the return type is Double although treeLimit is an IntParam. It is kept
 * unchanged so existing callers keep compiling -- consider narrowing it to Int.
 */
final def getTreeLimit: Double = $(treeLimit).toDouble
|
||||||
|
|
||||||
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
|
||||||
minChildWeight -> 1, maxDeltaStep -> 0,
|
minChildWeight -> 1, maxDeltaStep -> 0,
|
||||||
growPolicy -> "depthwise", maxBins -> 16,
|
growPolicy -> "depthwise", maxBins -> 16,
|
||||||
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
|
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
|
||||||
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
|
||||||
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
|
||||||
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0)
|
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] object BoosterParams {
|
private[spark] object BoosterParams {
|
||||||
|
|||||||
@ -158,6 +158,30 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait HasLeafPredictionCol extends Params {
|
||||||
|
/**
|
||||||
|
* Param for leaf prediction column name.
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
|
||||||
|
"name of the predictLeaf results")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getLeafPredictionCol: String = $(leafPredictionCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
trait HasContribPredictionCol extends Params {
|
||||||
|
/**
|
||||||
|
* Param for contribution prediction column name.
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
|
||||||
|
"name of the predictContrib results")
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getContribPredictionCol: String = $(contribPredictionCol)
|
||||||
|
}
|
||||||
|
|
||||||
trait HasBaseMarginCol extends Params {
|
trait HasBaseMarginCol extends Params {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -211,4 +211,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
|
|||||||
assert(testObjectiveHistory.length === 5)
|
assert(testObjectiveHistory.length === 5)
|
||||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test predictionLeaf") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
|
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Classification.train)
|
||||||
|
val test = buildDataFrame(Classification.test)
|
||||||
|
val groundTruth = test.count()
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setLeafPredictionCol("predictLeaf")
|
||||||
|
val resultDF = model.transform(test)
|
||||||
|
assert(resultDF.count == groundTruth)
|
||||||
|
assert(resultDF.columns.contains("predictLeaf"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionLeaf with empty column name") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
|
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Classification.train)
|
||||||
|
val test = buildDataFrame(Classification.test)
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setLeafPredictionCol("")
|
||||||
|
val resultDF = model.transform(test)
|
||||||
|
assert(!resultDF.columns.contains("predictLeaf"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionContrib") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
|
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Classification.train)
|
||||||
|
val test = buildDataFrame(Classification.test)
|
||||||
|
val groundTruth = test.count()
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setContribPredictionCol("predictContrib")
|
||||||
|
val resultDF = model.transform(buildDataFrame(Classification.test))
|
||||||
|
assert(resultDF.count == groundTruth)
|
||||||
|
assert(resultDF.columns.contains("predictContrib"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionContrib with empty column name") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
|
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Classification.train)
|
||||||
|
val test = buildDataFrame(Classification.test)
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setContribPredictionCol("")
|
||||||
|
val resultDF = model.transform(test)
|
||||||
|
assert(!resultDF.columns.contains("predictContrib"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionLeaf and predictionContrib") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||||
|
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Classification.train)
|
||||||
|
val test = buildDataFrame(Classification.test)
|
||||||
|
val groundTruth = test.count()
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setLeafPredictionCol("predictLeaf")
|
||||||
|
model.setContribPredictionCol("predictContrib")
|
||||||
|
val resultDF = model.transform(buildDataFrame(Classification.test))
|
||||||
|
assert(resultDF.count == groundTruth)
|
||||||
|
assert(resultDF.columns.contains("predictLeaf"))
|
||||||
|
assert(resultDF.columns.contains("predictContrib"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -120,4 +120,72 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
|
|||||||
val first = prediction.head.getAs[Double]("prediction")
|
val first = prediction.head.getAs[Double]("prediction")
|
||||||
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
|
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test predictionLeaf") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Regression.train)
|
||||||
|
val testDF = buildDataFrame(Regression.test)
|
||||||
|
val groundTruth = testDF.count()
|
||||||
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setLeafPredictionCol("predictLeaf")
|
||||||
|
val resultDF = model.transform(testDF)
|
||||||
|
assert(resultDF.count === groundTruth)
|
||||||
|
assert(resultDF.columns.contains("predictLeaf"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionLeaf with empty column name") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Regression.train)
|
||||||
|
val testDF = buildDataFrame(Regression.test)
|
||||||
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setLeafPredictionCol("")
|
||||||
|
val resultDF = model.transform(testDF)
|
||||||
|
assert(!resultDF.columns.contains("predictLeaf"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionContrib") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Regression.train)
|
||||||
|
val testDF = buildDataFrame(Regression.test)
|
||||||
|
val groundTruth = testDF.count()
|
||||||
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setContribPredictionCol("predictContrib")
|
||||||
|
val resultDF = model.transform(testDF)
|
||||||
|
assert(resultDF.count === groundTruth)
|
||||||
|
assert(resultDF.columns.contains("predictContrib"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionContrib with empty column name") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Regression.train)
|
||||||
|
val testDF = buildDataFrame(Regression.test)
|
||||||
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setContribPredictionCol("")
|
||||||
|
val resultDF = model.transform(testDF)
|
||||||
|
assert(!resultDF.columns.contains("predictContrib"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test predictionLeaf and predictionContrib") {
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||||
|
val training = buildDataFrame(Regression.train)
|
||||||
|
val testDF = buildDataFrame(Regression.test)
|
||||||
|
val groundTruth = testDF.count()
|
||||||
|
val xgb = new XGBoostRegressor(paramMap)
|
||||||
|
val model = xgb.fit(training)
|
||||||
|
model.setLeafPredictionCol("predictLeaf")
|
||||||
|
model.setContribPredictionCol("predictContrib")
|
||||||
|
val resultDF = model.transform(testDF)
|
||||||
|
assert(resultDF.count === groundTruth)
|
||||||
|
assert(resultDF.columns.contains("predictLeaf"))
|
||||||
|
assert(resultDF.columns.contains("predictContrib"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -125,8 +125,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
* @return predict result
|
* @return predict result
|
||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
|
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0):
|
||||||
: Array[Array[Float]] = {
|
Array[Array[Float]] = {
|
||||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
|
def predictLeaf(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
|
||||||
booster.predictLeaf(data.jDMatrix, treeLimit)
|
booster.predictLeaf(data.jDMatrix, treeLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user