[jvm-packages] handle NaN as missing value explicitly (#4309)
* handle nan
* handle nan explicitly
* make code better and handle sparse vector in spark
* Update XGBoostGeneralSuite.scala
parent 7ea5b772fb
commit ad4de0d718
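
The net behavior of this change: when `missing` is NaN, NaN entries are dropped
from each XGBLabeledPoint before DMatrix construction; when `missing` is an
explicit non-NaN value, entries equal to it are dropped instead, and sparse
input is rejected unless that value is 0.0. A minimal standalone sketch of the
filtering rule, using only the Scala standard library (the `keep` helper is
hypothetical, for illustration; the real logic lives in `processMissingValues`
below):

    // Mirrors the keepCondition logic introduced in processMissingValues.
    object MissingSketch {
      def keep(missing: Float): Float => Boolean =
        if (missing.isNaN) (v: Float) => !v.isNaN // NaN as missing: drop NaNs
        else (v: Float) => v != missing           // explicit missing: drop matches

      def main(args: Array[String]): Unit = {
        val values = Array(1.0f, Float.NaN, 0.0f, -0.1f)
        println(values.filter(keep(Float.NaN)).mkString(",")) // 1.0,0.0,-0.1
        println(values.filter(keep(-0.1f)).mkString(","))     // 1.0,NaN,0.0
      }
    }
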
@@ -70,30 +70,53 @@ private[spark] case class XGBLabeledPointGroup(
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
 
-  private[spark] def removeMissingValues(
-      xgbLabelPoints: Iterator[XGBLabeledPoint],
-      missing: Float): Iterator[XGBLabeledPoint] = {
-    if (!missing.isNaN) {
-      xgbLabelPoints.map { labeledPoint =>
-        val indicesBuilder = new mutable.ArrayBuilder.ofInt()
-        val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
-        for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
-          indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
-          valuesBuilder += value
-        }
-        labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
-      }
+  private def verifyMissingSetting(xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float):
+      Iterator[XGBLabeledPoint] = {
+    if (missing != 0.0f) {
+      xgbLabelPoints.map(labeledPoint => {
+        if (labeledPoint.indices != null) {
+          throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
+            " SparseVector as your feature format")
+        }
+        labeledPoint
+      })
     } else {
       xgbLabelPoints
     }
   }
 
-  private def removeMissingValuesWithGroup(
+  private def removeMissingValues(
+      xgbLabelPoints: Iterator[XGBLabeledPoint],
+      missing: Float,
+      keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
+    xgbLabelPoints.map { labeledPoint =>
+      val indicesBuilder = new mutable.ArrayBuilder.ofInt()
+      val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
+      for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
+        indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
+        valuesBuilder += value
+      }
+      labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
+    }
+  }
+
+  private[spark] def processMissingValues(
+      xgbLabelPoints: Iterator[XGBLabeledPoint],
+      missing: Float): Iterator[XGBLabeledPoint] = {
+    if (!missing.isNaN) {
+      removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
+        missing, (v: Float) => v != missing)
+    } else {
+      removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
+    }
+  }
+
+  private def processMissingValuesWithGroup(
       xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
       missing: Float): Iterator[Array[XGBLabeledPoint]] = {
     if (!missing.isNaN) {
       xgbLabelPointGroups.map {
-        labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
+        labeledPoints => XGBoost.processMissingValues(labeledPoints.iterator, missing).toArray
       }
     } else {
       xgbLabelPointGroups
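
Why `verifyMissingSetting` accepts only `missing` = 0.0f for sparse features: a
SparseVector never stores its zero entries, so by the time features reach the
filter the implicit zeros cannot be compared against any non-zero missing
value; 0.0 is the one setting whose semantics survive the sparse encoding. A
hedged illustration in plain Spark MLlib (no XGBoost involved):

    import org.apache.spark.ml.linalg.Vectors

    // A sparse vector stores only the non-zero entries; zeros are implicit.
    val sv = Vectors.sparse(4, Array(1, 3), Array(2.0, -0.1))
    println(sv.toArray.mkString(",")) // 0.0,2.0,0.0,-0.1
    // Treating -0.1 as missing would drop the stored -0.1, but the implicit
    // zeros at indices 0 and 2 were never materialized, so they can never be
    // tested against a non-zero missing value -- hence the exception above.
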
@@ -310,7 +333,7 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
         val watches = Watches.buildWatches(params,
-          removeMissingValues(labeledPoints, missing),
+          processMissingValues(labeledPoints, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
           obj, eval, prevBooster)
@@ -320,7 +343,7 @@ object XGBoost extends Serializable {
         nameAndLabeledPointSets =>
           val watches = Watches.buildWatches(
             nameAndLabeledPointSets.map {
-              case (name, iter) => (name, removeMissingValues(iter, missing))},
+              case (name, iter) => (name, processMissingValues(iter, missing))},
             getCacheDirName(useExternalMemory))
           buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
             obj, eval, prevBooster)
@@ -340,7 +363,7 @@ object XGBoost extends Serializable {
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
         val watches = Watches.buildWatchesWithGroup(params,
-          removeMissingValuesWithGroup(labeledPointGroups, missing),
+          processMissingValuesWithGroup(labeledPointGroups, missing),
           getCacheDirName(useExternalMemory))
         buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
       }).cache()
@@ -349,7 +372,7 @@ object XGBoost extends Serializable {
         labeledPointGroupSets => {
           val watches = Watches.buildWatchesWithGroup(
             labeledPointGroupSets.map {
-              case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
+              case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
             },
             getCacheDirName(useExternalMemory))
           buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
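
All four training call sites above now route their iterators through
processMissingValues / processMissingValuesWithGroup before the data reaches
Watches and, ultimately, DMatrix. A minimal sketch of that pattern, assuming
`labeledPoints: Iterator[XGBLabeledPoint]` and `missing: Float` are in scope
and that the caller lives inside the ml.dmlc.xgboost4j.scala.spark package
(the helper is private[spark]):

    // Values that survive the filter are handed to native XGBoost as real
    // features; everything dropped here is treated as missing.
    val cleaned = XGBoost.processMissingValues(labeledPoints, missing)
    val dm = new DMatrix(cleaned)
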
@@ -256,7 +256,7 @@ class XGBoostClassificationModel private[ml](
    */
   override def predict(features: Vector): Double = {
     import DataUtils._
-    val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
+    val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
     val probability = _booster.predict(data = dm)(0).map(_.toDouble)
     if (numClasses == 2) {
       math.round(probability(0))
@@ -303,7 +303,7 @@ class XGBoostClassificationModel private[ml](
           }
         }
         val dm = new DMatrix(
-          XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
+          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
         try {
           val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
@@ -247,7 +247,7 @@ class XGBoostRegressionModel private[ml] (
    */
   override def predict(features: Vector): Double = {
     import DataUtils._
-    val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
+    val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
     _booster.predict(data = dm)(0)(0)
   }
 
@@ -275,7 +275,7 @@ class XGBoostRegressionModel private[ml] (
           }
         }
         val dm = new DMatrix(
-          XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
+          XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
           cacheInfo)
         try {
           val Array(originalPredictionItr, predLeafItr, predContribItr) =
@@ -33,6 +33,8 @@ import scala.util.Random
 
 import ml.dmlc.xgboost4j.java.Rabit
 
+import org.apache.spark.ml.feature.VectorAssembler
+
 class XGBoostGeneralSuite extends FunSuite with PerTest {
 
   test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
@@ -227,26 +229,45 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     def buildDenseDataFrame(): DataFrame = {
       val numRows = 100
       val numCols = 5
 
       val data = (0 until numRows).map { x =>
         val label = Random.nextInt(2)
         val values = Array.tabulate[Double](numCols) { c =>
-          if (c == numCols - 1) -0.1 else Random.nextDouble
+          if (c == numCols - 1) 0 else Random.nextDouble
         }
 
         (label, Vectors.dense(values))
       }
 
       ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
     }
 
     val denseDF = buildDenseDataFrame().repartition(4)
     val paramMap = List("eta" -> "1", "max_depth" -> "2",
-      "objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
+      "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
     val model = new XGBoostClassifier(paramMap).fit(denseDF)
     model.transform(denseDF).collect()
   }
 
+  test("handle Float.NaN as missing value correctly") {
+    val spark = ss
+    import spark.implicits._
+    val testDF = Seq(
+      (1.0f, 0.0f, Float.NaN, 1.0),
+      (1.0f, 0.0f, 1.0f, 1.0),
+      (0.0f, 1.0f, 0.0f, 0.0),
+      (1.0f, 0.0f, 1.0f, 1.0),
+      (1.0f, Float.NaN, 0.0f, 0.0),
+      (0.0f, 0.0f, 0.0f, 0.0),
+      (0.0f, 1.0f, 0.0f, 1.0),
+      (Float.NaN, 0.0f, 0.0f, 1.0)
+    ).toDF("col1", "col2", "col3", "label")
+    val vectorAssembler = new VectorAssembler()
+      .setInputCols(Array("col1", "col2", "col3"))
+      .setOutputCol("features")
+    val inputDF = vectorAssembler.transform(testDF).select("features", "label")
+    val paramMap = List("eta" -> "1", "max_depth" -> "2",
+      "objective" -> "binary:logistic", "num_workers" -> 1).toMap
+    val model = new XGBoostClassifier(paramMap).fit(inputDF)
+    model.transform(inputDF).collect()
+  }
+
   test("training with spark parallelism checks disabled") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
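
A hedged usage note on the new test: it sets no "missing" parameter at all, so
training relies on the parameter defaulting to Float.NaN (as I read the
xgboost4j-spark params of this era) and exercises the NaN-dropping path end to
end. By contrast, combining sparse features with an explicit non-zero
"missing" now fails fast with the RuntimeException added in
verifyMissingSetting, for example:

    // Assumed setup: sparseInputDF is a DataFrame of (features, label) whose
    // vectors are sparse. This combination now throws at training time.
    val badParams = Map("objective" -> "binary:logistic",
      "missing" -> -0.1f, "num_workers" -> 1)
    // new XGBoostClassifier(badParams).fit(sparseInputDF) // RuntimeException
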