[jvm-packages] handle NaN as missing value explicitly (#4309)
* handle nan * handle nan explicitly * make code better and handle sparse vector in spark * Update XGBoostGeneralSuite.scala
This commit is contained in:
parent
7ea5b772fb
commit
ad4de0d718
@ -70,30 +70,53 @@ private[spark] case class XGBLabeledPointGroup(
|
|||||||
object XGBoost extends Serializable {
|
object XGBoost extends Serializable {
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
|
|
||||||
private[spark] def removeMissingValues(
|
private def verifyMissingSetting(xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float):
|
||||||
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
Iterator[XGBLabeledPoint] = {
|
||||||
missing: Float): Iterator[XGBLabeledPoint] = {
|
if (missing != 0.0f) {
|
||||||
if (!missing.isNaN) {
|
xgbLabelPoints.map(labeledPoint => {
|
||||||
xgbLabelPoints.map { labeledPoint =>
|
if (labeledPoint.indices != null) {
|
||||||
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
|
throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
|
||||||
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
|
" SparseVector as your feature format")
|
||||||
for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
|
|
||||||
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
|
|
||||||
valuesBuilder += value
|
|
||||||
}
|
}
|
||||||
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
|
labeledPoint
|
||||||
}
|
})
|
||||||
} else {
|
} else {
|
||||||
xgbLabelPoints
|
xgbLabelPoints
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def removeMissingValuesWithGroup(
|
private def removeMissingValues(
|
||||||
|
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
||||||
|
missing: Float,
|
||||||
|
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
|
||||||
|
xgbLabelPoints.map { labeledPoint =>
|
||||||
|
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
|
||||||
|
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
|
||||||
|
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
|
||||||
|
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
|
||||||
|
valuesBuilder += value
|
||||||
|
}
|
||||||
|
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private[spark] def processMissingValues(
|
||||||
|
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
||||||
|
missing: Float): Iterator[XGBLabeledPoint] = {
|
||||||
|
if (!missing.isNaN) {
|
||||||
|
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
|
||||||
|
missing, (v: Float) => v != missing)
|
||||||
|
} else {
|
||||||
|
removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def processMissingValuesWithGroup(
|
||||||
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
|
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
|
||||||
missing: Float): Iterator[Array[XGBLabeledPoint]] = {
|
missing: Float): Iterator[Array[XGBLabeledPoint]] = {
|
||||||
if (!missing.isNaN) {
|
if (!missing.isNaN) {
|
||||||
xgbLabelPointGroups.map {
|
xgbLabelPointGroups.map {
|
||||||
labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
|
labeledPoints => XGBoost.processMissingValues(labeledPoints.iterator, missing).toArray
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
xgbLabelPointGroups
|
xgbLabelPointGroups
|
||||||
@ -310,7 +333,7 @@ object XGBoost extends Serializable {
|
|||||||
if (evalSetsMap.isEmpty) {
|
if (evalSetsMap.isEmpty) {
|
||||||
trainingData.mapPartitions(labeledPoints => {
|
trainingData.mapPartitions(labeledPoints => {
|
||||||
val watches = Watches.buildWatches(params,
|
val watches = Watches.buildWatches(params,
|
||||||
removeMissingValues(labeledPoints, missing),
|
processMissingValues(labeledPoints, missing),
|
||||||
getCacheDirName(useExternalMemory))
|
getCacheDirName(useExternalMemory))
|
||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||||
obj, eval, prevBooster)
|
obj, eval, prevBooster)
|
||||||
@ -320,7 +343,7 @@ object XGBoost extends Serializable {
|
|||||||
nameAndLabeledPointSets =>
|
nameAndLabeledPointSets =>
|
||||||
val watches = Watches.buildWatches(
|
val watches = Watches.buildWatches(
|
||||||
nameAndLabeledPointSets.map {
|
nameAndLabeledPointSets.map {
|
||||||
case (name, iter) => (name, removeMissingValues(iter, missing))},
|
case (name, iter) => (name, processMissingValues(iter, missing))},
|
||||||
getCacheDirName(useExternalMemory))
|
getCacheDirName(useExternalMemory))
|
||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
|
||||||
obj, eval, prevBooster)
|
obj, eval, prevBooster)
|
||||||
@ -340,7 +363,7 @@ object XGBoost extends Serializable {
|
|||||||
if (evalSetsMap.isEmpty) {
|
if (evalSetsMap.isEmpty) {
|
||||||
trainingData.mapPartitions(labeledPointGroups => {
|
trainingData.mapPartitions(labeledPointGroups => {
|
||||||
val watches = Watches.buildWatchesWithGroup(params,
|
val watches = Watches.buildWatchesWithGroup(params,
|
||||||
removeMissingValuesWithGroup(labeledPointGroups, missing),
|
processMissingValuesWithGroup(labeledPointGroups, missing),
|
||||||
getCacheDirName(useExternalMemory))
|
getCacheDirName(useExternalMemory))
|
||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
@ -349,7 +372,7 @@ object XGBoost extends Serializable {
|
|||||||
labeledPointGroupSets => {
|
labeledPointGroupSets => {
|
||||||
val watches = Watches.buildWatchesWithGroup(
|
val watches = Watches.buildWatchesWithGroup(
|
||||||
labeledPointGroupSets.map {
|
labeledPointGroupSets.map {
|
||||||
case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
|
case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
|
||||||
},
|
},
|
||||||
getCacheDirName(useExternalMemory))
|
getCacheDirName(useExternalMemory))
|
||||||
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
|
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
|
||||||
|
|||||||
@ -256,7 +256,7 @@ class XGBoostClassificationModel private[ml](
|
|||||||
*/
|
*/
|
||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
|
||||||
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
|
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
|
||||||
if (numClasses == 2) {
|
if (numClasses == 2) {
|
||||||
math.round(probability(0))
|
math.round(probability(0))
|
||||||
@ -303,7 +303,7 @@ class XGBoostClassificationModel private[ml](
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val dm = new DMatrix(
|
val dm = new DMatrix(
|
||||||
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
||||||
cacheInfo)
|
cacheInfo)
|
||||||
try {
|
try {
|
||||||
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
|
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
|
||||||
|
|||||||
@ -247,7 +247,7 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
*/
|
*/
|
||||||
override def predict(features: Vector): Double = {
|
override def predict(features: Vector): Double = {
|
||||||
import DataUtils._
|
import DataUtils._
|
||||||
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
|
||||||
_booster.predict(data = dm)(0)(0)
|
_booster.predict(data = dm)(0)(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -275,7 +275,7 @@ class XGBoostRegressionModel private[ml] (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
val dm = new DMatrix(
|
val dm = new DMatrix(
|
||||||
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
||||||
cacheInfo)
|
cacheInfo)
|
||||||
try {
|
try {
|
||||||
val Array(originalPredictionItr, predLeafItr, predContribItr) =
|
val Array(originalPredictionItr, predLeafItr, predContribItr) =
|
||||||
|
|||||||
@ -33,6 +33,8 @@ import scala.util.Random
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Rabit
|
import ml.dmlc.xgboost4j.java.Rabit
|
||||||
|
|
||||||
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
|
|
||||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||||
|
|
||||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||||
@ -227,26 +229,45 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
def buildDenseDataFrame(): DataFrame = {
|
def buildDenseDataFrame(): DataFrame = {
|
||||||
val numRows = 100
|
val numRows = 100
|
||||||
val numCols = 5
|
val numCols = 5
|
||||||
|
|
||||||
val data = (0 until numRows).map { x =>
|
val data = (0 until numRows).map { x =>
|
||||||
val label = Random.nextInt(2)
|
val label = Random.nextInt(2)
|
||||||
val values = Array.tabulate[Double](numCols) { c =>
|
val values = Array.tabulate[Double](numCols) { c =>
|
||||||
if (c == numCols - 1) -0.1 else Random.nextDouble
|
if (c == numCols - 1) 0 else Random.nextDouble
|
||||||
}
|
}
|
||||||
|
|
||||||
(label, Vectors.dense(values))
|
(label, Vectors.dense(values))
|
||||||
}
|
}
|
||||||
|
|
||||||
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
|
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
|
||||||
}
|
}
|
||||||
|
|
||||||
val denseDF = buildDenseDataFrame().repartition(4)
|
val denseDF = buildDenseDataFrame().repartition(4)
|
||||||
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||||
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
|
"objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
|
||||||
val model = new XGBoostClassifier(paramMap).fit(denseDF)
|
val model = new XGBoostClassifier(paramMap).fit(denseDF)
|
||||||
model.transform(denseDF).collect()
|
model.transform(denseDF).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("handle Float.NaN as missing value correctly") {
|
||||||
|
val spark = ss
|
||||||
|
import spark.implicits._
|
||||||
|
val testDF = Seq(
|
||||||
|
(1.0f, 0.0f, Float.NaN, 1.0),
|
||||||
|
(1.0f, 0.0f, 1.0f, 1.0),
|
||||||
|
(0.0f, 1.0f, 0.0f, 0.0),
|
||||||
|
(1.0f, 0.0f, 1.0f, 1.0),
|
||||||
|
(1.0f, Float.NaN, 0.0f, 0.0),
|
||||||
|
(0.0f, 0.0f, 0.0f, 0.0),
|
||||||
|
(0.0f, 1.0f, 0.0f, 1.0),
|
||||||
|
(Float.NaN, 0.0f, 0.0f, 1.0)
|
||||||
|
).toDF("col1", "col2", "col3", "label")
|
||||||
|
val vectorAssembler = new VectorAssembler()
|
||||||
|
.setInputCols(Array("col1", "col2", "col3"))
|
||||||
|
.setOutputCol("features")
|
||||||
|
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
|
||||||
|
val paramMap = List("eta" -> "1", "max_depth" -> "2",
|
||||||
|
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
|
||||||
|
val model = new XGBoostClassifier(paramMap).fit(inputDF)
|
||||||
|
model.transform(inputDF).collect()
|
||||||
|
}
|
||||||
|
|
||||||
test("training with spark parallelism checks disabled") {
|
test("training with spark parallelism checks disabled") {
|
||||||
val eval = new EvalError()
|
val eval = new EvalError()
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user