[jvm-packages] allowing chaining prediction (#4667)
* add test for chaining prediction
* update rabit
* Update XGBoostGeneralSuite.scala
This commit is contained in:
@@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
@@ -26,6 +28,8 @@ import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.TaskContext
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
test("distributed training with the specified worker number") {
|
||||
@@ -395,4 +399,22 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
assert(clsRet1 sameElements clsRet3)
|
||||
assert(clsRet1 sameElements clsRet4)
|
||||
}
|
||||
|
||||
test("chaining the prediction") {
  // Checks that a loaded model can be applied a second time to a DataFrame that
  // already carries the (renamed) output of a first application (#4667).
  val numRows = 100000
  val modelPath = getClass.getResource("/model/0.82/model").getPath
  val model = XGBoostClassificationModel.read.load(modelPath)
  // Constant (feature, label) rows; the values themselves do not matter for this test.
  val df = ss.createDataFrame(Seq.fill(numRows)((1, 1)))
    .toDF("feature", "label")
    .repartition(5)
  // Every column whose name does not contain "label" is assembled into "features".
  val assembler = new VectorAssembler()
    .setInputCols(df.columns.filter(!_.contains("label")))
    .setOutputCol("features")
  // First pass. The output columns are renamed before the second pass —
  // presumably so the second transform can write its own output columns
  // without a name clash (assumes the model uses these output column names).
  val df1 = model.transform(assembler.transform(df))
    .withColumnRenamed("prediction", "prediction1")
    .withColumnRenamed("rawPrediction", "rawPrediction1")
    .withColumnRenamed("probability", "probability1")
  // Second pass, chained on the output of the first.
  val df2 = model.transform(df1)
  // Materialize both stages (first pass, then chained pass) and check that
  // neither of them drops or duplicates rows.
  assert(df1.collect().length === numRows)
  assert(df2.collect().length === numRows)
  // The chained pass must keep every column of its input and add its own output.
  assert(df1.columns.forall(df2.columns.contains))
  assert(df2.columns.length > df1.columns.length)
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user