diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 5f81393e0..ec76c9177 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.params._ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} - import org.apache.hadoop.fs.Path + import org.apache.spark.TaskContext import org.apache.spark.ml.classification._ import org.apache.spark.ml.linalg._ @@ -39,8 +39,11 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql._ import org.json4s.DefaultFormats +import org.apache.spark.broadcast.Broadcast + private[spark] trait XGBoostClassifierParams extends GeneralParams with LearningTaskParams with BoosterParams with HasWeightCol with HasBaseMarginCol with HasNumClass with ParamMapFuncs + with HasLeafPredictionCol with HasContribPredictionCol class XGBoostClassifier ( override val uid: String, @@ -235,6 +238,12 @@ class XGBoostClassificationModel private[ml]( this } + def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value) + + def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value) + + def setTreeLimit(value: Int): this.type = set(treeLimit, value) + /** * Single instance prediction. * Note: The performance is not ideal, use it carefully! @@ -287,22 +296,15 @@ class XGBoostClassificationModel private[ml]( null } } - val dm = new DMatrix( XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)), cacheInfo) try { - val rawPredictionItr = { - bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator - } - val probabilityItr = { - bBooster.value.predict(dm, outPutMargin = false).map(Row(_)).iterator - } + val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) = + producePredictionItrs(bBooster, dm) Rabit.shutdown() - rowItr1.zip(rawPredictionItr).zip(probabilityItr).map { - case ((originals: Row, rawPrediction: Row), probability: Row) => - Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq) - } + produceResultIterator(rowItr1, rawPredictionItr, probabilityItr, predLeafItr, + predContribItr) } finally { dm.delete() } @@ -313,7 +315,82 @@ class XGBoostClassificationModel private[ml]( bBooster.unpersist(blocking = false) - dataset.sparkSession.createDataFrame(rdd, schema) + dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema)) + } + + private def produceResultIterator( + originalRowItr: Iterator[Row], + rawPredictionItr: Iterator[Row], + probabilityItr: Iterator[Row], + predLeafItr: Iterator[Row], + predContribItr: Iterator[Row]): Iterator[Row] = { + // the following implementation is to be improved + if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty && + isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) { + originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr). + map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row), + contribs: Row) => + Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++ + contribs.toSeq) + } + } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty && + (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) { + originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr). + map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) => + Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq) + } + } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) && + isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) { + originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr). + map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) => + Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq) + } + } else { + originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map { + case ((originals: Row, rawPrediction: Row), probability: Row) => + Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq) + } + } + } + + private def generateResultSchema(fixedSchema: StructType): StructType = { + var resultSchema = fixedSchema + if (isDefined(leafPredictionCol)) { + resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType = + ArrayType(FloatType, containsNull = false), nullable = false)) + } + if (isDefined(contribPredictionCol)) { + resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType = + ArrayType(FloatType, containsNull = false), nullable = false)) + } + resultSchema + } + + private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix): + Array[Iterator[Row]] = { + val rawPredictionItr = { + broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)). + map(Row(_)).iterator + } + val probabilityItr = { + broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)). + map(Row(_)).iterator + } + val predLeafItr = { + if (isDefined(leafPredictionCol)) { + broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator + } else { + Iterator() + } + } + val predContribItr = { + if (isDefined(contribPredictionCol)) { + broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator + } else { + Iterator() + } + } + Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) } override def transform(dataset: Dataset[_]): DataFrame = { @@ -329,11 +406,11 @@ class XGBoostClassificationModel private[ml]( var outputData = transformInternal(dataset) var numColsOutput = 0 - val rawPredictionUDF = udf { (rawPrediction: mutable.WrappedArray[Float]) => + val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] => Vectors.dense(rawPrediction.map(_.toDouble).toArray) } - val probabilityUDF = udf { (probability: mutable.WrappedArray[Float]) => + val probabilityUDF = udf { probability: mutable.WrappedArray[Float] => if (numClasses == 2) { Vectors.dense(Array(1 - probability(0), probability(0)).map(_.toDouble)) } else { @@ -341,7 +418,7 @@ class XGBoostClassificationModel private[ml]( } } - val predictUDF = udf { (probability: mutable.WrappedArray[Float]) => + val predictUDF = udf { probability: mutable.WrappedArray[Float] => // From XGBoost probability to MLlib prediction val probabilities = if (numClasses == 2) { Array(1 - probability(0), probability(0)).map(_.toDouble) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 2d6568300..0fe3452d6 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -24,8 +24,8 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} - import org.apache.hadoop.fs.Path + import org.apache.spark.TaskContext import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.ml.param.shared.HasWeightCol @@ -37,12 +37,13 @@ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.json4s.DefaultFormats - import scala.collection.mutable +import org.apache.spark.broadcast.Broadcast + private[spark] trait XGBoostRegressorParams extends GeneralParams with BoosterParams with LearningTaskParams with HasBaseMarginCol with HasWeightCol with HasGroupCol - with ParamMapFuncs + with ParamMapFuncs with HasLeafPredictionCol with HasContribPredictionCol class XGBoostRegressor ( override val uid: String, @@ -231,6 +232,12 @@ class XGBoostRegressionModel private[ml] ( this } + def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value) + + def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value) + + def setTreeLimit(value: Int): this.type = set(treeLimit, value) + /** * Single instance prediction. * Note: The performance is not ideal, use it carefully! @@ -270,14 +277,10 @@ class XGBoostRegressionModel private[ml] ( XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)), cacheInfo) try { - val originalPredictionItr = { - bBooster.value.predict(dm).map(Row(_)).iterator - } + val Array(originalPredictionItr, predLeafItr, predContribItr) = + producePredictionItrs(bBooster, dm) Rabit.shutdown() - rowItr1.zip(originalPredictionItr).map { - case (originals: Row, originalPrediction: Row) => - Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq) - } + produceResultIterator(rowItr1, originalPredictionItr, predLeafItr, predContribItr) } finally { dm.delete() } @@ -285,10 +288,77 @@ class XGBoostRegressionModel private[ml] ( Iterator[Row]() } } - bBooster.unpersist(blocking = false) + dataset.sparkSession.createDataFrame(rdd, generateResultSchema(schema)) + } - dataset.sparkSession.createDataFrame(rdd, schema) + private def produceResultIterator( + originalRowItr: Iterator[Row], + predictionItr: Iterator[Row], + predLeafItr: Iterator[Row], + predContribItr: Iterator[Row]): Iterator[Row] = { + // the following implementation is to be improved + if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty && + isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) { + originalRowItr.zip(predictionItr).zip(predLeafItr).zip(predContribItr). + map { case (((originals: Row, prediction: Row), leaves: Row), contribs: Row) => + Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq ++ contribs.toSeq) + } + } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty && + (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) { + originalRowItr.zip(predictionItr).zip(predLeafItr). + map { case ((originals: Row, prediction: Row), leaves: Row) => + Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq) + } + } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) && + isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) { + originalRowItr.zip(predictionItr).zip(predContribItr). + map { case ((originals: Row, prediction: Row), contribs: Row) => + Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ contribs.toSeq) + } + } else { + originalRowItr.zip(predictionItr).map { + case (originals: Row, originalPrediction: Row) => + Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq) + } + } + } + + private def generateResultSchema(fixedSchema: StructType): StructType = { + var resultSchema = fixedSchema + if (isDefined(leafPredictionCol)) { + resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType = + ArrayType(FloatType, containsNull = false), nullable = false)) + } + if (isDefined(contribPredictionCol)) { + resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType = + ArrayType(FloatType, containsNull = false), nullable = false)) + } + resultSchema + } + + private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix): + Array[Iterator[Row]] = { + val originalPredictionItr = { + broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator + } + val predLeafItr = { + if (isDefined(leafPredictionCol)) { + broadcastBooster.value.predictLeaf(dm, $(treeLimit)). + map(Row(_)).iterator + } else { + Iterator() + } + } + val predContribItr = { + if (isDefined(contribPredictionCol)) { + broadcastBooster.value.predictContrib(dm, $(treeLimit)). + map(Row(_)).iterator + } else { + Iterator() + } + } + Array(originalPredictionItr, predLeafItr, predContribItr) } override def transform(dataset: Dataset[_]): DataFrame = { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala index c3be3601b..89e2ef102 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala @@ -237,13 +237,18 @@ private[spark] trait BoosterParams extends Params { final def getLambdaBias: Double = $(lambdaBias) + final val treeLimit = new IntParam(this, name = "treeLimit", + doc = "number of trees used in the prediction; defaults to 0 (use all trees).") + + final def getTreeLimit: Double = $(lambdaBias) + setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0, growPolicy -> "depthwise", maxBins -> 16, subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1, lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03, scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree", - rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0) + rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0) } private[spark] object BoosterParams { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 6a27195b8..f0ad69ad5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -158,6 +158,30 @@ private[spark] trait GeneralParams extends Params { ) } +trait HasLeafPredictionCol extends Params { + /** + * Param for leaf prediction column name. + * @group param + */ + final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol", + "name of the predictLeaf results") + + /** @group getParam */ + final def getLeafPredictionCol: String = $(leafPredictionCol) +} + +trait HasContribPredictionCol extends Params { + /** + * Param for contribution prediction column name. + * @group param + */ + final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol", + "name of the predictContrib results") + + /** @group getParam */ + final def getContribPredictionCol: String = $(contribPredictionCol) +} + trait HasBaseMarginCol extends Params { /** diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index d2814b8a1..86f9b575a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -211,4 +211,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { assert(testObjectiveHistory.length === 5) assert(model.summary.trainObjectiveHistory !== testObjectiveHistory) } + + test("test predictionLeaf") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "train_test_ratio" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Classification.train) + val test = buildDataFrame(Classification.test) + val groundTruth = test.count() + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(training) + model.setLeafPredictionCol("predictLeaf") + val resultDF = model.transform(test) + assert(resultDF.count == groundTruth) + assert(resultDF.columns.contains("predictLeaf")) + } + + test("test predictionLeaf with empty column name") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "train_test_ratio" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Classification.train) + val test = buildDataFrame(Classification.test) + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(training) + model.setLeafPredictionCol("") + val resultDF = model.transform(test) + assert(!resultDF.columns.contains("predictLeaf")) + } + + test("test predictionContrib") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "train_test_ratio" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Classification.train) + val test = buildDataFrame(Classification.test) + val groundTruth = test.count() + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(training) + model.setContribPredictionCol("predictContrib") + val resultDF = model.transform(buildDataFrame(Classification.test)) + assert(resultDF.count == groundTruth) + assert(resultDF.columns.contains("predictContrib")) + } + + test("test predictionContrib with empty column name") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "train_test_ratio" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Classification.train) + val test = buildDataFrame(Classification.test) + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(training) + model.setContribPredictionCol("") + val resultDF = model.transform(test) + assert(!resultDF.columns.contains("predictContrib")) + } + + test("test predictionLeaf and predictionContrib") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "train_test_ratio" -> "0.5", + "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Classification.train) + val test = buildDataFrame(Classification.test) + val groundTruth = test.count() + val xgb = new XGBoostClassifier(paramMap) + val model = xgb.fit(training) + model.setLeafPredictionCol("predictLeaf") + model.setContribPredictionCol("predictContrib") + val resultDF = model.transform(buildDataFrame(Classification.test)) + assert(resultDF.count == groundTruth) + assert(resultDF.columns.contains("predictLeaf")) + assert(resultDF.columns.contains("predictContrib")) + } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 8dba73e61..8679a1517 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -120,4 +120,72 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { val first = prediction.head.getAs[Double]("prediction") prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f)) } + + test("test predictionLeaf") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val groundTruth = testDF.count() + val xgb = new XGBoostRegressor(paramMap) + val model = xgb.fit(training) + model.setLeafPredictionCol("predictLeaf") + val resultDF = model.transform(testDF) + assert(resultDF.count === groundTruth) + assert(resultDF.columns.contains("predictLeaf")) + } + + test("test predictionLeaf with empty column name") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val xgb = new XGBoostRegressor(paramMap) + val model = xgb.fit(training) + model.setLeafPredictionCol("") + val resultDF = model.transform(testDF) + assert(!resultDF.columns.contains("predictLeaf")) + } + + test("test predictionContrib") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val groundTruth = testDF.count() + val xgb = new XGBoostRegressor(paramMap) + val model = xgb.fit(training) + model.setContribPredictionCol("predictContrib") + val resultDF = model.transform(testDF) + assert(resultDF.count === groundTruth) + assert(resultDF.columns.contains("predictContrib")) + } + + test("test predictionContrib with empty column name") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val xgb = new XGBoostRegressor(paramMap) + val model = xgb.fit(training) + model.setContribPredictionCol("") + val resultDF = model.transform(testDF) + assert(!resultDF.columns.contains("predictContrib")) + } + + test("test predictionLeaf and predictionContrib") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers) + val training = buildDataFrame(Regression.train) + val testDF = buildDataFrame(Regression.test) + val groundTruth = testDF.count() + val xgb = new XGBoostRegressor(paramMap) + val model = xgb.fit(training) + model.setLeafPredictionCol("predictLeaf") + model.setContribPredictionCol("predictContrib") + val resultDF = model.transform(testDF) + assert(resultDF.count === groundTruth) + assert(resultDF.columns.contains("predictLeaf")) + assert(resultDF.columns.contains("predictContrib")) + } } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index f885a6881..c9013d3c7 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -125,8 +125,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) * @return predict result */ @throws(classOf[XGBoostError]) - def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0) - : Array[Array[Float]] = { + def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0): + Array[Array[Float]] = { booster.predict(data.jDMatrix, outPutMargin, treeLimit) } @@ -139,7 +139,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) * @throws XGBoostError native error */ @throws(classOf[XGBoostError]) - def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = { + def predictLeaf(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = { booster.predictLeaf(data.jDMatrix, treeLimit) }