[jvm-packages] allowing chaining prediction (#4667)
* add test for chaining prediction * update rabit * Update XGBoostGeneralSuite.scala
This commit is contained in:
parent
3c506b076e
commit
01b0c9047c
@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
|
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
@ -26,6 +28,8 @@ import org.apache.hadoop.fs.{FileSystem, Path}
|
|||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
|
|
||||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||||
|
|
||||||
test("distributed training with the specified worker number") {
|
test("distributed training with the specified worker number") {
|
||||||
@ -395,4 +399,22 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
assert(clsRet1 sameElements clsRet3)
|
assert(clsRet1 sameElements clsRet3)
|
||||||
assert(clsRet1 sameElements clsRet4)
|
assert(clsRet1 sameElements clsRet4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("chaining the prediction") {
|
||||||
|
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
||||||
|
val model = XGBoostClassificationModel.read.load(modelPath)
|
||||||
|
val r = new Random(0)
|
||||||
|
val df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
|
||||||
|
toDF("feature", "label").repartition(5)
|
||||||
|
val assembler = new VectorAssembler()
|
||||||
|
.setInputCols(df.columns.filter(!_.contains("label")))
|
||||||
|
.setOutputCol("features")
|
||||||
|
val df1 = model.transform(assembler.transform(df)).withColumnRenamed(
|
||||||
|
"prediction", "prediction1").withColumnRenamed(
|
||||||
|
"rawPrediction", "rawPrediction1").withColumnRenamed(
|
||||||
|
"probability", "probability1")
|
||||||
|
val df2 = model.transform(df1)
|
||||||
|
df1.collect()
|
||||||
|
df2.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1
|
Subproject commit e1c8056f6a0ee1c42fd00430b74176e67db66a9f
|
||||||
Loading…
x
Reference in New Issue
Block a user