[jvm-packages] handle NaN as missing value explicitly (#4309)

* handle nan

* handle nan explicitly

* make code better and handle sparse vector in spark

* Update XGBoostGeneralSuite.scala
This commit is contained in:
Nan Zhu
2019-03-30 19:34:26 +08:00
committed by GitHub
parent 7ea5b772fb
commit ad4de0d718
4 changed files with 72 additions and 28 deletions

View File

@@ -33,6 +33,8 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostGeneralSuite extends FunSuite with PerTest {
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
@@ -227,26 +229,45 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
val numCols = 5
val data = (0 until numRows).map { x =>
val label = Random.nextInt(2)
val values = Array.tabulate[Double](numCols) { c =>
if (c == numCols - 1) -0.1 else Random.nextDouble
if (c == numCols - 1) 0 else Random.nextDouble
}
(label, Vectors.dense(values))
}
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
}
val denseDF = buildDenseDataFrame().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
"objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
val model = new XGBoostClassifier(paramMap).fit(denseDF)
model.transform(denseDF).collect()
}
test("handle Float.NaN as missing value correctly") {
val spark = ss
import spark.implicits._
val testDF = Seq(
(1.0f, 0.0f, Float.NaN, 1.0),
(1.0f, 0.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0),
(1.0f, Float.NaN, 0.0f, 0.0),
(0.0f, 0.0f, 0.0f, 0.0),
(0.0f, 1.0f, 0.0f, 1.0),
(Float.NaN, 0.0f, 0.0f, 1.0)
).toDF("col1", "col2", "col3", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}
test("training with spark parallelism checks disabled") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)