From 01b0c9047c686a96ab73ff672194e2dff47e7923 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Wed, 17 Jul 2019 08:50:27 -0700 Subject: [PATCH] [jvm-packages] allowing chaining prediction (#4667) * add test for chaining prediction * update rabit * Update XGBoostGeneralSuite.scala --- .../scala/spark/XGBoostGeneralSuite.scala | 22 +++++++++++++++++++ rabit | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 36400f99f..4a06c36f2 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark import java.nio.file.Files +import scala.util.Random + import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} @@ -26,6 +28,8 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.TaskContext import org.scalatest.FunSuite +import org.apache.spark.ml.feature.VectorAssembler + class XGBoostGeneralSuite extends FunSuite with PerTest { test("distributed training with the specified worker number") { @@ -395,4 +399,22 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(clsRet1 sameElements clsRet3) assert(clsRet1 sameElements clsRet4) } + + test("chaining the prediction") { + val modelPath = getClass.getResource("/model/0.82/model").getPath + val model = XGBoostClassificationModel.read.load(modelPath) + val r = new Random(0) + val df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))). + toDF("feature", "label").repartition(5) + val assembler = new VectorAssembler() + .setInputCols(df.columns.filter(!_.contains("label"))) + .setOutputCol("features") + val df1 = model.transform(assembler.transform(df)).withColumnRenamed( + "prediction", "prediction1").withColumnRenamed( + "rawPrediction", "rawPrediction1").withColumnRenamed( + "probability", "probability1") + val df2 = model.transform(df1) + df1.collect() + df2.collect() + } } diff --git a/rabit b/rabit index 65b718a5e..e1c8056f6 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1 +Subproject commit e1c8056f6a0ee1c42fd00430b74176e67db66a9f