[jvm-packages] the current version of xgboost does not consider missing value in prediction (#3529)
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* fix scalastyle error
* fix scalastyle error
* consider missing value in prediction
* handle single prediction instance
* fix type conversion
This commit is contained in:
parent
3b62e75f2e
commit
b546321c83
@ -55,11 +55,11 @@ object TrackerConf {
|
||||
object XGBoost extends Serializable {
|
||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||
|
||||
private def removeMissingValues(
|
||||
denseLabeledPoints: Iterator[XGBLabeledPoint],
|
||||
private[spark] def removeMissingValues(
|
||||
xgbLabelPoints: Iterator[XGBLabeledPoint],
|
||||
missing: Float): Iterator[XGBLabeledPoint] = {
|
||||
if (!missing.isNaN) {
|
||||
denseLabeledPoints.map { labeledPoint =>
|
||||
xgbLabelPoints.map { labeledPoint =>
|
||||
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
|
||||
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
|
||||
for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
|
||||
@ -69,7 +69,7 @@ object XGBoost extends Serializable {
|
||||
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
|
||||
}
|
||||
} else {
|
||||
denseLabeledPoints
|
||||
xgbLabelPoints
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -241,7 +241,7 @@ class XGBoostClassificationModel private[ml](
|
||||
*/
|
||||
override def predict(features: Vector): Double = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(Iterator(features.asXGB))
|
||||
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
||||
val probability = _booster.predict(data = dm)(0)
|
||||
if (numClasses == 2) {
|
||||
math.round(probability(0))
|
||||
@ -272,12 +272,12 @@ class XGBoostClassificationModel private[ml](
|
||||
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
|
||||
val appName = dataset.sparkSession.sparkContext.appName
|
||||
|
||||
val rdd = dataset.rdd.mapPartitions { rowIterator =>
|
||||
val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
|
||||
if (rowIterator.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val (rowItr1, rowItr2) = rowIterator.duplicate
|
||||
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
|
||||
val featuresIterator = rowItr2.map(row => row.getAs[Vector](
|
||||
$(featuresCol))).toList.iterator
|
||||
import DataUtils._
|
||||
val cacheInfo = {
|
||||
@ -288,7 +288,9 @@ class XGBoostClassificationModel private[ml](
|
||||
}
|
||||
}
|
||||
|
||||
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
|
||||
val dm = new DMatrix(
|
||||
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
||||
cacheInfo)
|
||||
try {
|
||||
val rawPredictionItr = {
|
||||
bBooster.value.predict(dm, outPutMargin = true).map(Row(_)).iterator
|
||||
|
||||
@ -237,7 +237,7 @@ class XGBoostRegressionModel private[ml] (
|
||||
*/
|
||||
override def predict(features: Vector): Double = {
|
||||
import DataUtils._
|
||||
val dm = new DMatrix(Iterator(features.asXGB))
|
||||
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
|
||||
_booster.predict(data = dm)(0)(0)
|
||||
}
|
||||
|
||||
@ -250,12 +250,12 @@ class XGBoostRegressionModel private[ml] (
|
||||
val bBooster = dataset.sparkSession.sparkContext.broadcast(_booster)
|
||||
val appName = dataset.sparkSession.sparkContext.appName
|
||||
|
||||
val rdd = dataset.rdd.mapPartitions { rowIterator =>
|
||||
val rdd = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator =>
|
||||
if (rowIterator.hasNext) {
|
||||
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Rabit.init(rabitEnv.asJava)
|
||||
val (rowItr1, rowItr2) = rowIterator.duplicate
|
||||
val featuresIterator = rowItr2.map(row => row.asInstanceOf[Row].getAs[Vector](
|
||||
val featuresIterator = rowItr2.map(row => row.getAs[Vector](
|
||||
$(featuresCol))).toList.iterator
|
||||
import DataUtils._
|
||||
val cacheInfo = {
|
||||
@ -266,7 +266,9 @@ class XGBoostRegressionModel private[ml] (
|
||||
}
|
||||
}
|
||||
|
||||
val dm = new DMatrix(featuresIterator.map(_.asXGB), cacheInfo)
|
||||
val dm = new DMatrix(
|
||||
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
|
||||
cacheInfo)
|
||||
try {
|
||||
val originalPredictionItr = {
|
||||
bBooster.value.predict(dm).map(Row(_)).iterator
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user