[jvm-packages] the current version of xgboost does not consider missing values in prediction (#3529)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* consider missing values in prediction

* handle single prediction instance

* fix type conversion
This commit is contained in:
Nan Zhu 2018-07-30 14:16:24 -07:00 committed by GitHub
parent 3b62e75f2e
commit b546321c83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 12 deletions

View File

@@ -55,11 +55,11 @@ object TrackerConf {
object XGBoost extends Serializable { object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark") private val logger = LogFactory.getLog("XGBoostSpark")
private def removeMissingValues( private[spark] def removeMissingValues(
denseLabeledPoints: Iterator[XGBLabeledPoint], xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float): Iterator[XGBLabeledPoint] = { missing: Float): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) { if (!missing.isNaN) {
denseLabeledPoints.map { labeledPoint => xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt() val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat() val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) { for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
@@ -69,7 +69,7 @@ object XGBoost extends Serializable {
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result()) labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
} }
} else { } else {
denseLabeledPoints xgbLabelPoints
} }
} }

View File

@@ -241,7 +241,7 @@ class XGBoostClassificationModel private[ml](
*/ */
override def predict(features: Vector): Double = { override def predict(features: Vector): Double = {
import DataUtils._ import DataUtils._
val dm = new DMatrix(Iterator(features.asXGB)) val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val probability = _booster.predict(data = dm)(0) val probability = _booster.predict(data = dm)(0)
if (numClasses == 2) { if (numClasses == 2) {
math.round(probability(0)) math.round(probability(0))
@@ -272,12 +272,12 @@ class XGBoostClassificationModel private[ml](
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster) val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName val appName = dataset.sparkSession.sparkContext.appName
val rdd = dataset.rdd.mapPartitions { rowIterator => val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) { if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector]( val featuresIterator = rowItr2.map(row => row.getAs[Vector](
$(featuresCol))).toList.iterator $(featuresCol))).toList.iterator
import DataUtils._ import DataUtils._
val cacheInfo = { val cacheInfo = {
@@ -288,7 +288,9 @@ class XGBoostClassificationModel private[ml](
} }
} }
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo) val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try { try {
val rawPredictionItr = { val rawPredictionItr = {
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator

View File

@@ -237,7 +237,7 @@ class XGBoostRegressionModel private[ml] (
*/ */
override def predict(features: Vector): Double = { override def predict(features: Vector): Double = {
import DataUtils._ import DataUtils._
val dm = new DMatrix(Iterator(features.asXGB)) val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
_booster.predict(data = dm)(0)(0) _booster.predict(data = dm)(0)(0)
} }
@@ -250,12 +250,12 @@ class XGBoostRegressionModel private[ml] (
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster) val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
val appName = dataset.sparkSession.sparkContext.appName val appName = dataset.sparkSession.sparkContext.appName
val rdd = dataset.rdd.mapPartitions { rowIterator => val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
if (rowIterator.hasNext) { if (rowIterator.hasNext) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Rabit.init(rabitEnv.asJava) Rabit.init(rabitEnv.asJava)
val (rowItr1, rowItr2) = rowIterator.duplicate val (rowItr1, rowItr2) = rowIterator.duplicate
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector]( val featuresIterator = rowItr2.map(row => row.getAs[Vector](
$(featuresCol))).toList.iterator $(featuresCol))).toList.iterator
import DataUtils._ import DataUtils._
val cacheInfo = { val cacheInfo = {
@@ -266,7 +266,9 @@ class XGBoostRegressionModel private[ml] (
} }
} }
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo) val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try { try {
val originalPredictionItr = { val originalPredictionItr = {
bBooster.value.predict(dm).map(Row(_)).iterator bBooster.value.predict(dm).map(Row(_)).iterator