From 3f198b9fefa0480f46b7dbc9c0d8564c5e95c542 Mon Sep 17 00:00:00 2001
From: Nan Zhu
Date: Mon, 29 Aug 2016 21:45:49 -0400
Subject: [PATCH] [jvm-packages] allow training with missing values in
 xgboost-spark (#1525)

* allow training with missing values in xgboost-spark

* fix compilation error

* fix bug
---
 .../dmlc/xgboost4j/scala/spark/XGBoost.scala  | 37 +++++++++++++++++--
 .../xgboost4j/scala/spark/XGBoostSuite.scala  |  2 +-
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 7cc5f7658..6cbfbf72c 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.Path
+import org.apache.spark.mllib.linalg.SparseVector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, TaskContext}
@@ -35,12 +37,37 @@ object XGBoost extends Serializable {
     new XGBoostModel(booster)
   }
 
+  private def fromDenseToSparseLabeledPoints(
+      denseLabeledPoints: Iterator[LabeledPoint],
+      missing: Float): Iterator[LabeledPoint] = {
+    if (!missing.isNaN) {
+      val sparseLabeledPoints = new ListBuffer[LabeledPoint]
+      for (labelPoint <- denseLabeledPoints) {
+        val dVector = labelPoint.features.toDense
+        val indices = new ListBuffer[Int]
+        val values = new ListBuffer[Double]
+        for (i <- dVector.values.indices) {
+          if (dVector.values(i) != missing) {
+            indices += i
+            values += dVector.values(i)
+          }
+        }
+        val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
+          values.toArray)
+        sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector)
+      }
+      sparseLabeledPoints.iterator
+    } else {
+      denseLabeledPoints
+    }
+  }
+
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, Any],
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
-      useExternalMemory: Boolean): RDD[Booster] = {
+      useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
     import DataUtils._
     val partitionedData = {
       if (numWorkers > trainingData.partitions.length) {
@@ -71,7 +98,8 @@ object XGBoost extends Serializable {
           null
         }
       }
-      val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
+      val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
+      val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
       booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
         watches = new mutable.HashMap[String, DMatrix] {
           put("train", trainingSet)
@@ -97,13 +125,14 @@
    * @param eval the user-defined evaluation function, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
+   * @param missing the value representing the missing value in the dataset
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-      useExternalMemory: Boolean = false): XGBoostModel = {
+      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
     val tracker = new RabitTracker(nWorkers)
     implicit val sc = trainingData.sparkContext
@@ -119,7 +148,7 @@ object XGBoost extends Serializable {
     }
     require(tracker.start(), "FAULT: Failed to start tracker")
     val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
-      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
+      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
     val sparkJobThread = new Thread() {
       override def run() {
         // force the job
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 2b6131546..639a19c91 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -128,7 +128,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorkers = 2, round = 5, null, null, useExternalMemory = false)
+      numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === 2)
     val boosters = boosterRDD.collect()
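
Usage note (not part of the patch): a minimal sketch of calling XGBoost.train with the new `missing` parameter, assuming a live SparkContext `sc` (as in spark-shell). The toy data, parameter values, and the 0.0f sentinel are illustrative assumptions; only the `missing` argument itself is what this patch adds.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    // Toy dense data in which 0.0 is the agreed-upon missing-value sentinel.
    val trainingData = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)),
      LabeledPoint(0.0, Vectors.dense(0.0, 2.0, 0.0))
    ), numSlices = 2)
    val params = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
    // missing = 0.0f makes each worker re-encode its dense vectors as SparseVectors
    // with the sentinel entries dropped before the DMatrix is built; the default
    // Float.NaN leaves the input iterator untouched.
    val model = XGBoost.train(trainingData, params, round = 5, nWorkers = 2, missing = 0.0f)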