[jvm-packages] allowing chaining prediction (#4667)
* add test for chaining prediction * update rabit * Update XGBoostGeneralSuite.scala
This commit is contained in:
@@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
|
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
@@ -26,6 +28,8 @@ import org.apache.hadoop.fs.{FileSystem, Path}
|
|||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
|
|
||||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||||
|
|
||||||
test("distributed training with the specified worker number") {
|
test("distributed training with the specified worker number") {
|
||||||
@@ -395,4 +399,22 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
assert(clsRet1 sameElements clsRet3)
|
assert(clsRet1 sameElements clsRet3)
|
||||||
assert(clsRet1 sameElements clsRet4)
|
assert(clsRet1 sameElements clsRet4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("chaining the prediction") {
|
||||||
|
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
||||||
|
val model = XGBoostClassificationModel.read.load(modelPath)
|
||||||
|
val r = new Random(0)
|
||||||
|
val df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
|
||||||
|
toDF("feature", "label").repartition(5)
|
||||||
|
val assembler = new VectorAssembler()
|
||||||
|
.setInputCols(df.columns.filter(!_.contains("label")))
|
||||||
|
.setOutputCol("features")
|
||||||
|
val df1 = model.transform(assembler.transform(df)).withColumnRenamed(
|
||||||
|
"prediction", "prediction1").withColumnRenamed(
|
||||||
|
"rawPrediction", "rawPrediction1").withColumnRenamed(
|
||||||
|
"probability", "probability1")
|
||||||
|
val df2 = model.transform(df1)
|
||||||
|
df1.collect()
|
||||||
|
df2.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
2
rabit
2
rabit
Submodule rabit updated: 65b718a5e7...e1c8056f6a
Reference in New Issue
Block a user