[jvm-packages] handle NaN as missing value explicitly (#4309)
* handle NaN
* handle NaN explicitly
* make code better and handle sparse vectors in Spark
* Update XGBoostGeneralSuite.scala
This commit is contained in:
@@ -33,6 +33,8 @@ import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||
@@ -227,26 +229,45 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
def buildDenseDataFrame(): DataFrame = {
|
||||
val numRows = 100
|
||||
val numCols = 5
|
||||
|
||||
val data = (0 until numRows).map { x =>
|
||||
val label = Random.nextInt(2)
|
||||
val values = Array.tabulate[Double](numCols) { c =>
|
||||
if (c == numCols - 1) -0.1 else Random.nextDouble
|
||||
if (c == numCols - 1) 0 else Random.nextDouble
|
||||
}
|
||||
|
||||
(label, Vectors.dense(values))
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
|
||||
}
|
||||
|
||||
val denseDF = buildDenseDataFrame().repartition(4)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
|
||||
"objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
|
||||
val model = new XGBoostClassifier(paramMap).fit(denseDF)
|
||||
model.transform(denseDF).collect()
|
||||
}
|
||||
|
||||
// Smoke test: trains an XGBoostClassifier on features that contain Float.NaN and then
// runs prediction on the same data. There are no assertions; the test passes when
// fit() and transform() complete without throwing.
test("handle Float.NaN as missing value correctly") {
|
||||
val spark = ss
|
||||
import spark.implicits._
|
||||
// Eight rows of (col1, col2, col3, label); each feature column holds exactly one Float.NaN.
val testDF = Seq(
|
||||
(1.0f, 0.0f, Float.NaN, 1.0),
|
||||
(1.0f, 0.0f, 1.0f, 1.0),
|
||||
(0.0f, 1.0f, 0.0f, 0.0),
|
||||
(1.0f, 0.0f, 1.0f, 1.0),
|
||||
(1.0f, Float.NaN, 0.0f, 0.0),
|
||||
(0.0f, 0.0f, 0.0f, 0.0),
|
||||
(0.0f, 1.0f, 0.0f, 1.0),
|
||||
(Float.NaN, 0.0f, 0.0f, 1.0)
|
||||
).toDF("col1", "col2", "col3", "label")
|
||||
// Pack the three feature columns into a single "features" vector column.
val vectorAssembler = new VectorAssembler()
|
||||
.setInputCols(Array("col1", "col2", "col3"))
|
||||
.setOutputCol("features")
|
||||
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||
// No "missing" entry is supplied here — presumably NaN is then treated as the
// missing value by default; TODO confirm against XGBoostClassifier's parameters.
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
|
||||
val model = new XGBoostClassifier(paramMap).fit(inputDF)
|
||||
// collect() forces the prediction job to run; the returned rows are discarded.
model.transform(inputDF).collect()
|
||||
}
|
||||
|
||||
test("training with spark parallelism checks disabled") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
Reference in New Issue
Block a user