[jvm-packages] the current version of xgboost does not consider missing value in prediction (#3529)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* consider missing value in prediction

* handle single prediction instance

* fix type conversion
This commit is contained in:
Nan Zhu 2018-07-30 14:16:24 -07:00 committed by GitHub
parent 3b62e75f2e
commit b546321c83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 12 deletions

View File

@ -55,11 +55,11 @@ object TrackerConf {
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private def removeMissingValues(
denseLabeledPoints: Iterator[XGBLabeledPoint],
private[spark] def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
denseLabeledPoints.map { labeledPoint =>
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
@ -69,7 +69,7 @@ object XGBoost extends Serializable {
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
} else {
denseLabeledPoints
xgbLabelPoints
}
}

View File

@ -241,7 +241,7 @@ class XGBoostClassificationModel private[ml](
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(Iterator(features.asXGB))
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val probability = _booster.predict(data = dm)(0)
if (numClasses == 2) {
math.round(probability(0))
@ -272,12 +272,12 @@ class XGBoostClassificationModel private[ml](
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName
val rdd = dataset.rdd.mapPartitions { rowIterator =>
val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
val featuresIterator = rowItr2.map(row => row.getAs[Vector](
$(featuresCol))).toList.iterator
import DataUtils._
val cacheInfo = {
@ -288,7 +288,9 @@ class XGBoostClassificationModel private[ml](
}
}
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try {
val rawPredictionItr = {
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator

View File

@ -237,7 +237,7 @@ class XGBoostRegressionModel private[ml] (
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(Iterator(features.asXGB))
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
_booster.predict(data = dm)(0)(0)
}
@ -250,12 +250,12 @@ class XGBoostRegressionModel private[ml] (
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName
val rdd = dataset.rdd.mapPartitions { rowIterator =>
val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
val featuresIterator = rowItr2.map(row => row.getAs[Vector](
$(featuresCol))).toList.iterator
import DataUtils._
val cacheInfo = {
@ -266,7 +266,9 @@ class XGBoostRegressionModel private[ml] (
}
}
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try {
val originalPredictionItr = {
bBooster.value.predict(dm).map(Row(_)).iterator