[jvm-packages] allow training with missing values in xgboost-spark (#1525)
* allow training with missing values in xgboost-spark
* fix compilation error
* fix bug
commit 3f198b9fef (parent 6014839961)
@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
 
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.Path
+import org.apache.spark.mllib.linalg.SparseVector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkContext, TaskContext}
@@ -35,12 +37,37 @@ object XGBoost extends Serializable {
     new XGBoostModel(booster)
   }
 
+  private def fromDenseToSparseLabeledPoints(
+      denseLabeledPoints: Iterator[LabeledPoint],
+      missing: Float): Iterator[LabeledPoint] = {
+    if (!missing.isNaN) {
+      val sparseLabeledPoints = new ListBuffer[LabeledPoint]
+      for (labelPoint <- denseLabeledPoints) {
+        val dVector = labelPoint.features.toDense
+        val indices = new ListBuffer[Int]
+        val values = new ListBuffer[Double]
+        for (i <- dVector.values.indices) {
+          if (dVector.values(i) != missing) {
+            indices += i
+            values += dVector.values(i)
+          }
+        }
+        val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
+          values.toArray)
+        sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector)
+      }
+      sparseLabeledPoints.iterator
+    } else {
+      denseLabeledPoints
+    }
+  }
+
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, Any],
       rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
-      useExternalMemory: Boolean): RDD[Booster] = {
+      useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
     import DataUtils._
     val partitionedData = {
       if (numWorkers > trainingData.partitions.length) {
@@ -71,7 +98,8 @@ object XGBoost extends Serializable {
             null
           }
         }
-        val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
+        val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
+        val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
         booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
           watches = new mutable.HashMap[String, DMatrix] {
             put("train", trainingSet)
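Taken together, the new helper and the call site above densify each partition's rows and rebuild them as SparseVectors with the sentinel entries dropped before the DMatrix is constructed. A minimal standalone sketch of that filtering step (the -999.0 sentinel and the toy row are made-up illustration values, not taken from the patch):

    import org.apache.spark.mllib.linalg.{DenseVector, SparseVector}

    val missing = -999.0f
    val dense = new DenseVector(Array(1.0, -999.0, 3.0))
    // Keep only (value, index) pairs whose value is not the sentinel.
    val kept = dense.values.zipWithIndex.filter { case (v, _) => v != missing.toDouble }
    val sparse = new SparseVector(dense.size, kept.map(_._2), kept.map(_._1))
    // sparse keeps indices 0 and 2 with values 1.0 and 3.0
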
@@ -97,13 +125,14 @@ object XGBoost extends Serializable {
    * @param eval the user-defined evaluation function, null by default
    * @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
    *                          true, the user may save the RAM cost for running XGBoost within Spark
+   * @param missing the value representing missing values in the dataset
    * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
    * @return XGBoostModel when successful training
    */
   @throws(classOf[XGBoostError])
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
       nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-      useExternalMemory: Boolean = false): XGBoostModel = {
+      useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
     require(nWorkers > 0, "you must specify more than 0 workers")
     val tracker = new RabitTracker(nWorkers)
     implicit val sc = trainingData.sparkContext
@@ -119,7 +148,7 @@ object XGBoost extends Serializable {
     }
     require(tracker.start(), "FAULT: Failed to start tracker")
     val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
-      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
+      tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
     val sparkJobThread = new Thread() {
       override def run() {
         // force the job
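Because missing defaults to Float.NaN and the conversion is guarded by !missing.isNaN, existing callers of train compile and behave exactly as before; the filtering only kicks in when an explicit sentinel is supplied. A hypothetical call under the new signature (the path and the -1.0f sentinel are placeholders, and sc is assumed to be an existing SparkContext):

    import ml.dmlc.xgboost4j.scala.spark.XGBoost
    import org.apache.spark.mllib.util.MLUtils

    val trainingData = MLUtils.loadLibSVMFile(sc, "/tmp/training.libsvm")
    val params = Map("eta" -> "1", "max_depth" -> "6", "objective" -> "binary:logistic")
    // Entries equal to -1.0f are dropped per partition before each DMatrix is built.
    val model = XGBoost.train(trainingData, params, round = 5, nWorkers = 2, missing = -1.0f)
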
@@ -128,7 +128,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
       new scala.collection.mutable.HashMap[String, String],
-      numWorkers = 2, round = 5, null, null, useExternalMemory = false)
+      numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === 2)
     val boosters = boosterRDD.collect()
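The suite keeps relying on the default missing = Float.NaN; an explicit sentinel could be exercised the same way (a sketch mirroring the test call above; buildDistributedBoosters is private[spark], so this only compiles from within the ml.dmlc.xgboost4j.scala.spark package, and trainingRDD is an assumed RDD[LabeledPoint]):

    val boosterRDD = XGBoost.buildDistributedBoosters(
      trainingRDD,
      List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
        "objective" -> "binary:logistic").toMap,
      new scala.collection.mutable.HashMap[String, String],
      numWorkers = 2, round = 5, obj = null, eval = null,
      useExternalMemory = false, missing = -1.0f)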