[jvm-packages] allowing chaining prediction (#4667)
* add test for chaining prediction * update rabit * Update XGBoostGeneralSuite.scala
This commit is contained in:
parent
3c506b076e
commit
01b0c9047c
@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
@ -26,6 +28,8 @@ import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.TaskContext
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
test("distributed training with the specified worker number") {
|
||||
@ -395,4 +399,22 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
assert(clsRet1 sameElements clsRet3)
|
||||
assert(clsRet1 sameElements clsRet4)
|
||||
}
|
||||
|
||||
test("chaining the prediction") {
|
||||
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
||||
val model = XGBoostClassificationModel.read.load(modelPath)
|
||||
val r = new Random(0)
|
||||
val df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
|
||||
toDF("feature", "label").repartition(5)
|
||||
val assembler = new VectorAssembler()
|
||||
.setInputCols(df.columns.filter(!_.contains("label")))
|
||||
.setOutputCol("features")
|
||||
val df1 = model.transform(assembler.transform(df)).withColumnRenamed(
|
||||
"prediction", "prediction1").withColumnRenamed(
|
||||
"rawPrediction", "rawPrediction1").withColumnRenamed(
|
||||
"probability", "probability1")
|
||||
val df2 = model.transform(df1)
|
||||
df1.collect()
|
||||
df2.collect()
|
||||
}
|
||||
}
|
||||
|
||||
2
rabit
2
rabit
@ -1 +1 @@
|
||||
Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1
|
||||
Subproject commit e1c8056f6a0ee1c42fd00430b74176e67db66a9f
|
||||
Loading…
x
Reference in New Issue
Block a user