[jvm-packages] allowing chaining prediction (#4667)

* add test for chaining prediction

* update rabit

* Update XGBoostGeneralSuite.scala
This commit is contained in:
Nan Zhu 2019-07-17 08:50:27 -07:00 committed by GitHub
parent 3c506b076e
commit 01b0c9047c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 1 deletions

View File

@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files import java.nio.file.Files
import scala.util.Random
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@ -26,6 +28,8 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.TaskContext import org.apache.spark.TaskContext
import org.scalatest.FunSuite import org.scalatest.FunSuite
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostGeneralSuite extends FunSuite with PerTest { class XGBoostGeneralSuite extends FunSuite with PerTest {
test("distributed training with the specified worker number") { test("distributed training with the specified worker number") {
@ -395,4 +399,22 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(clsRet1 sameElements clsRet3) assert(clsRet1 sameElements clsRet3)
assert(clsRet1 sameElements clsRet4) assert(clsRet1 sameElements clsRet4)
} }
// Exercises chained prediction: the DataFrame produced by one `transform` call is fed
// into a second `transform` call on the same model, and both results are materialized.
test("chaining the prediction") {
  // Load the model stored under the test resources.
  val modelPath = getClass.getResource("/model/0.82/model").getPath
  val model = XGBoostClassificationModel.read.load(modelPath)
  // 100000 identical (1, 1) rows, repartitioned into 5 partitions.
  val df = ss.createDataFrame(Seq.fill(100000)((1, 1)))
    .toDF("feature", "label")
    .repartition(5)
  // Assemble every column whose name does not contain "label" into the "features" vector.
  val assembler = new VectorAssembler()
    .setInputCols(df.columns.filter(!_.contains("label")))
    .setOutputCol("features")
  // First prediction pass. The output columns are renamed, presumably so that the second
  // pass can add its own columns without a name clash (confirm against the model's
  // configured output column names).
  val df1 = model.transform(assembler.transform(df))
    .withColumnRenamed("prediction", "prediction1")
    .withColumnRenamed("rawPrediction", "rawPrediction1")
    .withColumnRenamed("probability", "probability1")
  // Second prediction pass, chained on the output of the first.
  val df2 = model.transform(df1)
  // Force evaluation of both DataFrames; the test passes if neither collect throws.
  df1.collect()
  df2.collect()
}
} }

2
rabit

@ -1 +1 @@
Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1 Subproject commit e1c8056f6a0ee1c42fd00430b74176e67db66a9f