[BREAKING][jvm-packages] fix the non-zero missing value handling (#4349)
* fix the nan and non-zero missing value handling
* fix nan handling part
* add missing value
* Update MissingValueHandlingSuite.scala
* Update MissingValueHandlingSuite.scala
* stylistic fix
This commit is contained in:
@@ -18,9 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
import java.nio.file.Files
|
||||
import java.util.Properties
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.collection.{AbstractIterator, mutable}
|
||||
import scala.util.Random
|
||||
|
||||
@@ -32,8 +30,8 @@ import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkException, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
|
||||
@@ -75,8 +73,9 @@ object XGBoost extends Serializable {
|
||||
if (missing != 0.0f) {
|
||||
xgbLabelPoints.map(labeledPoint => {
|
||||
if (labeledPoint.indices != null) {
|
||||
throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
|
||||
" SparseVector as your feature format")
|
||||
throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
|
||||
s" set value $missing) when you have SparseVector or Empty vector as your feature" +
|
||||
" format")
|
||||
}
|
||||
labeledPoint
|
||||
})
|
||||
@@ -107,7 +106,8 @@ object XGBoost extends Serializable {
|
||||
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
|
||||
missing, (v: Float) => v != missing)
|
||||
} else {
|
||||
removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
|
||||
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
|
||||
missing, (v: Float) => !v.isNaN)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user