[jvm-packages] handle NaN as missing value explicitly (#4309)

* handle nan

* handle nan explicitly

* make code better and handle sparse vector in spark

* Update XGBoostGeneralSuite.scala
This commit is contained in:
Nan Zhu
2019-03-30 19:34:26 +08:00
committed by GitHub
parent 7ea5b772fb
commit ad4de0d718
4 changed files with 72 additions and 28 deletions

View File

@@ -70,30 +70,53 @@ private[spark] case class XGBLabeledPointGroup(
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private[spark] def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
private def verifyMissingSetting(xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float):
Iterator[XGBLabeledPoint] = {
if (missing != 0.0f) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
" SparseVector as your feature format")
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
labeledPoint
})
} else {
xgbLabelPoints
}
}
private def removeMissingValuesWithGroup(
/**
 * Rebuilds each labeled point so that it holds only the feature values accepted by
 * `keepCondition`.
 *
 * @param xgbLabelPoints labeled points to filter; they are transformed through `Iterator.map`
 * @param missing the configured missing value; it is not read in this body, the filtering is
 *                driven entirely by `keepCondition`
 * @param keepCondition predicate returning true for the values that must be retained
 * @return copies of the input points whose `indices` and `values` contain only the kept entries
 */
private def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
xgbLabelPoints.map { labeledPoint =>
// Primitive array builders collect the surviving (index, value) pairs.
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
// With null `indices` the position inside `values` is used as the feature index;
// otherwise the index already stored at that position is carried over.
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
// The returned copy always carries an explicit indices array, even if the input had none.
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
}
/**
 * Removes the entries regarded as missing from every labeled point.
 *
 * When `missing` is not NaN the points are first passed through `verifyMissingSetting` and
 * then every value equal to `missing` is dropped. When `missing` is NaN, every NaN value is
 * dropped instead.
 *
 * @param xgbLabelPoints labeled points to clean
 * @param missing the value that marks a missing feature
 * @return an iterator over the points with the missing entries removed
 */
private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
// NOTE(review): verifyMissingSetting appears to reject a non-zero `missing` for points
// that carry explicit indices - confirm against its definition.
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
missing, (v: Float) => v != missing)
} else {
// NaN never compares equal to anything, itself included, so `v != missing` would keep
// every value; the explicit isNaN test is required here.
removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
}
}
private def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) {
xgbLabelPointGroups.map {
labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
labeledPoints => XGBoost.processMissingValues(labeledPoints.iterator, missing).toArray
}
} else {
xgbLabelPointGroups
@@ -310,7 +333,7 @@ object XGBoost extends Serializable {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(params,
removeMissingValues(labeledPoints, missing),
processMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
@@ -320,7 +343,7 @@ object XGBoost extends Serializable {
nameAndLabeledPointSets =>
val watches = Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, removeMissingValues(iter, missing))},
case (name, iter) => (name, processMissingValues(iter, missing))},
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
@@ -340,7 +363,7 @@ object XGBoost extends Serializable {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(params,
removeMissingValuesWithGroup(labeledPointGroups, missing),
processMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
}).cache()
@@ -349,7 +372,7 @@ object XGBoost extends Serializable {
labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
},
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,

View File

@@ -256,7 +256,7 @@ class XGBoostClassificationModel private[ml](
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
if (numClasses == 2) {
math.round(probability(0))
@@ -303,7 +303,7 @@ class XGBoostClassificationModel private[ml](
}
}
val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =

View File

@@ -247,7 +247,7 @@ class XGBoostRegressionModel private[ml] (
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
_booster.predict(data = dm)(0)(0)
}
@@ -275,7 +275,7 @@ class XGBoostRegressionModel private[ml] (
}
}
val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try {
val Array(originalPredictionItr, predLeafItr, predContribItr) =