[jvm-packages] allow training with missing values in xgboost-spark (#1525)
* allow training with missing values in xgboost-spark
* fix compilation error
* fix bug
parent 6014839961
commit 3f198b9fef
@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ListBuffer

 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.Path
+import org.apache.spark.mllib.linalg.SparseVector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, TaskContext}
@@ -35,12 +37,37 @@ object XGBoost extends Serializable {
     new XGBoostModel(booster)
   }

+  private def fromDenseToSparseLabeledPoints(
+      denseLabeledPoints: Iterator[LabeledPoint],
+      missing: Float): Iterator[LabeledPoint] = {
+    if (!missing.isNaN) {
+      val sparseLabeledPoints = new ListBuffer[LabeledPoint]
+      for (labelPoint <- denseLabeledPoints) {
+        val dVector = labelPoint.features.toDense
+        val indices = new ListBuffer[Int]
+        val values = new ListBuffer[Double]
+        for (i <- dVector.values.indices) {
+          if (dVector.values(i) != missing) {
+            indices += i
+            values += dVector.values(i)
+          }
+        }
+        val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
+          values.toArray)
+        sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector)
+      }
+      sparseLabeledPoints.iterator
+    } else {
+      denseLabeledPoints
+    }
+  }
+
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, Any],
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
-      useExternalMemory: Boolean): RDD[Booster] = {
+      useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
     import DataUtils._
     val partitionedData = {
       if (numWorkers > trainingData.partitions.length) {
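The helper above rewrites each dense vector into a SparseVector only when a concrete missing sentinel is supplied; with the default Float.NaN the iterator passes through untouched. Below is a minimal standalone sketch of the same filtering step, assuming a sentinel of -999.0f; the object name MissingValueSketch and the sample data are illustrative, not part of the commit.

import org.apache.spark.mllib.linalg.{SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint

object MissingValueSketch {
  def main(args: Array[String]): Unit = {
    val missing = -999.0f
    val dense = LabeledPoint(1.0, Vectors.dense(0.5, -999.0, 2.0, -999.0))

    val dVector = dense.features.toDense
    // keep (value, index) pairs whose value differs from the sentinel
    val kept = dVector.values.zipWithIndex.filter { case (v, _) => v != missing }
    val sparse = new SparseVector(
      dVector.values.length,   // dimensionality is preserved
      kept.map(_._2),          // surviving indices
      kept.map(_._1))          // surviving values

    println(sparse)            // prints (4,[0,2],[0.5,2.0])
  }
}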
@@ -71,7 +98,8 @@ object XGBoost extends Serializable {
           null
         }
       }
-      val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
+      val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
+      val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
       booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
         watches = new mutable.HashMap[String, DMatrix] {
           put("train", trainingSet)
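In the commit, the conversion runs inside each worker's partition just before the JDMatrix is built. Here is a hedged sketch of that per-partition wiring against plain RDD operations; the function name toSparsePartitions is hypothetical, and the DMatrix construction is elided because it requires the native xgboost4j library.

import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

def toSparsePartitions(trainingData: RDD[LabeledPoint], missing: Float): RDD[LabeledPoint] = {
  trainingData.mapPartitions { trainingSamples =>
    if (missing.isNaN) {
      trainingSamples                      // NaN sentinel: leave the partition as-is
    } else {
      trainingSamples.map { labelPoint =>  // lazily rewrite one point at a time
        val dVector = labelPoint.features.toDense
        val kept = dVector.values.zipWithIndex.filter { case (v, _) => v != missing }
        LabeledPoint(labelPoint.label,
          new SparseVector(dVector.values.length, kept.map(_._2), kept.map(_._1)))
      }
    }
  }
}

One design observation: the committed helper materializes each partition into a ListBuffer before returning an iterator, whereas the iterator.map above stays lazy, which may matter when useExternalMemory is set and partitions are expected to stream through the cache file.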
@@ -97,13 +125,14 @@ object XGBoost extends Serializable {
    * @param eval the user-defined evaluation function, null by default
    * @param useExternalMemory whether to use the external memory cache; setting this flag to
    *                          true may save RAM when running XGBoost within Spark
+   * @param missing the value treated as missing in the dataset
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
    * @return XGBoostModel when training succeeds
    */
   @throws(classOf[XGBoostError])
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-      useExternalMemory: Boolean = false): XGBoostModel = {
+      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
     val tracker = new RabitTracker(nWorkers)
     implicit val sc = trainingData.sparkContext
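Caller-side, the widened signature means missing can be passed without touching the other defaults. A hypothetical usage sketch, assuming an existing trainingRDD: RDD[LabeledPoint]; the parameter values are illustrative only.

val params = Map(
  "eta" -> "1", "max_depth" -> "6", "silent" -> "0",
  "objective" -> "binary:logistic")

// -999.0f is dropped during the dense-to-sparse pass on each partition;
// the default Float.NaN would leave the vectors unmodified.
val model: XGBoostModel = XGBoost.train(
  trainingRDD, params, round = 5, nWorkers = 2, missing = -999.0f)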
@@ -119,7 +148,7 @@ object XGBoost extends Serializable {
     }
     require(tracker.start(), "FAULT: Failed to start tracker")
     val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
-      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
+      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
     val sparkJobThread = new Thread() {
       override def run() {
         // force the job
@@ -128,7 +128,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorkers = 2, round = 5, null, null, useExternalMemory = false)
+      numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === 2)
     val boosters = boosterRDD.collect()
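A natural follow-up check, not part of this commit, would exercise the new parameter through the same private[spark] entry point. The sketch below mirrors the suite's style and assumes its existing trainingRDD fixture; the sentinel value is illustrative.

val boosterRDDWithMissing = XGBoost.buildDistributedBoosters(
  trainingRDD,
  List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
    "objective" -> "binary:logistic").toMap,
  new scala.collection.mutable.HashMap[String, String],
  numWorkers = 2, round = 5, obj = null, eval = null,
  useExternalMemory = false, missing = -999.0f)
// training should still produce one booster per worker
assert(boosterRDDWithMissing.count() === 2)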