[BREAKING][jvm-packages] fix the non-zero missing value handling (#4349)

* fix the NaN and non-zero missing value handling

* fix the NaN handling part

* add missing value

* Update MissingValueHandlingSuite.scala

* Update MissingValueHandlingSuite.scala

* stylistic fix
This commit is contained in:
Nan Zhu
2019-04-26 11:10:33 -07:00
committed by GitHub
parent 2d875ec019
commit 995698b0cb
5 changed files with 201 additions and 88 deletions

View File

@@ -18,9 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import java.util.Properties
import scala.collection.mutable.ListBuffer
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
@@ -32,8 +30,8 @@ import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkException, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
@@ -75,8 +73,9 @@ object XGBoost extends Serializable {
if (missing != 0.0f) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
" SparseVector as your feature format")
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
s" set value $missing) when you have SparseVector or Empty vector as your feature" +
" format")
}
labeledPoint
})
@@ -107,7 +106,8 @@ object XGBoost extends Serializable {
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
missing, (v: Float) => v != missing)
} else {
removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
missing, (v: Float) => !v.isNaN)
}
}