[jvm-packages] xgboost4j-spark external memory (#1219)

* implement external memory support for XGBoost4J

* remove extra space

* enable external memory for prediction

* update doc
Nan Zhu 2016-05-22 14:01:28 -04:00
parent 587999755f
commit c85b9012c6
8 changed files with 62 additions and 16 deletions
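Taken together, the diffs below thread a new useExternalMemory flag from XGBoost.train down to the per-partition DMatrix construction, and add a matching useExternalCache flag to XGBoostModel.predict. A rough end-to-end sketch of how the two flags are used together; the input paths, parameter values and worker count are placeholders, only the two flags themselves come from this commit:

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

val sparkConf = new SparkConf().setAppName("XGBoost-external-memory-example")
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)

// placeholder libsvm inputs
val trainRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///path/to/train.libsvm")
val testRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///path/to/test.libsvm").map(_.features)
val paramMap = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic")

// useExternalMemory = true backs each worker's training DMatrix with an on-disk cache
val model = XGBoost.train(trainRDD, paramMap, round = 5, nWorkers = 5,
  useExternalMemory = true)
// prediction can likewise keep each partition's test DMatrix on disk
val predictions = model.predict(testRDD, useExternalCache = true)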

View File

@@ -61,7 +61,6 @@ object DistTrainWithSpark {
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)
val sc = new SparkContext(sparkConf)
val inputTrainPath = args(1)
val outputModelPath = args(2)
// number of iterations
@@ -73,7 +72,8 @@ object DistTrainWithSpark {
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// use 5 distributed workers to train the model
val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5)
// useExternalMemory indicates whether to use external memory cache during training
val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5, useExternalMemory = true)
// save model to HDFS path
model.saveModelToHadoop(outputModelPath)
}

View File

@@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference: add a # followed by a cache prefix name
//several cache files with this prefix will be generated
//currently only conversion from a libsvm file is supported
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
DMatrix trainMat = new DMatrix("../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../demo/data/agaricus.txt.test#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
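The same path#prefix convention carries over to the Scala DMatrix wrapper used by xgboost4j-spark. A minimal sketch, reusing the demo data path above:

import ml.dmlc.xgboost4j.scala.DMatrix

// the "#dtrain.cache" suffix tells XGBoost to stream the libsvm file through
// on-disk cache pages prefixed with dtrain.cache instead of loading it all into RAM
val trainMat = new DMatrix("../demo/data/agaricus.txt.train#dtrain.cache")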

View File

@@ -28,7 +28,7 @@ object DistTrainWithSpark {
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)
@@ -45,7 +45,8 @@ object DistTrainWithSpark {
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt)
val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
useExternalMemory = true)
xgboostModel.predict(new DMatrix(testSet))
// save model to HDFS path
xgboostModel.saveModelAsHadoopFile(outputModelPath)

View File

@@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Paths
import scala.collection.mutable
import scala.collection.JavaConverters._
@@ -41,7 +43,8 @@ object XGBoost extends Serializable {
trainingData: RDD[LabeledPoint],
xgBoostConfMap: Map[String, Any],
rabitEnv: mutable.Map[String, String],
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
useExternalMemory: Boolean): RDD[Booster] = {
import DataUtils._
val partitionedData = {
if (numWorkers > trainingData.partitions.length) {
@@ -54,11 +57,19 @@ object XGBoost extends Serializable {
trainingData
}
}
val appName = partitionedData.context.appName
partitionedData.mapPartitions {
trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
val trainingSet = new DMatrix(new JDMatrix(trainingSamples, null))
val cacheFileName: String = {
if (useExternalMemory && trainingSamples.hasNext) {
s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
obj, eval)
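To make the naming above concrete (the values below are made up for illustration), every partition derives a unique cache prefix from the Spark application name and its partition id, which is what later lets the cache files be found and deleted by prefix:

// illustration only: how cacheFileName is composed for one partition
val appName = "XGBoostSuite"   // SparkContext application name
val partitionId = 3            // TaskContext.getPartitionId()
val cacheFileName = s"$appName-dtrain_cache-$partitionId"
// => "XGBoostSuite-dtrain_cache-3", passed to JDMatrix as the cache file prefix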
@@ -76,12 +87,15 @@ object XGBoost extends Serializable {
* workers equals the number of partitions of the trainingData RDD
* @param obj the user-defined objective function, null by default
* @param eval the user-defined evaluation function, null by default
* @param useExternalMemory indicates whether to use the external memory cache; setting this flag to
* true may reduce the RAM cost of running XGBoost within Spark
* @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
* @return XGBoostModel when training succeeds
*/
@throws(classOf[XGBoostError])
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
useExternalMemory: Boolean = false): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
val tracker = new RabitTracker(nWorkers)
implicit val sc = trainingData.sparkContext
@@ -97,7 +111,7 @@ object XGBoost extends Serializable {
}
require(tracker.start(), "FAULT: Failed to start tracker")
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval)
tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
val sparkJobThread = new Thread() {
override def run() {
// force the job

View File

@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.SparkContext
import org.apache.spark.{TaskContext, SparkContext}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
@@ -27,13 +27,23 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser
/**
* Predict result with the given test set (represented as RDD)
* @param testSet test set represented as RDD
* @param useExternalCache whether to use external cache for the test set
*/
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
val appName = testSet.context.appName
testSet.mapPartitions { testSamples =>
if (testSamples.hasNext) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
val cacheFileName = {
if (useExternalCache) {
s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
} else {
null
}
}
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
Iterator(broadcastBooster.value.predict(dMatrix))
} else {
Iterator()
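The prediction path mirrors training: when useExternalCache is true, each non-empty partition builds its test DMatrix behind an appName-dtest_cache-partitionId prefix instead of holding the data purely in memory. A minimal usage sketch; xgboostModel and testRDD stand in for an already trained model and an RDD[Vector] of test features:

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// one Array[Array[Float]] of predictions per non-empty partition;
// with useExternalCache = true the per-partition DMatrix is disk-backed
val predictions: RDD[Array[Array[Float]]] =
  xgboostModel.predict(testRDD, useExternalCache = true)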

View File

@@ -127,7 +127,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorkers = 2, round = 5, null, null)
numWorkers = 2, round = 5, null, null, false)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
val boosters = boosterRDD.collect()
@@ -210,4 +210,26 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
println(xgBoostModel.predict(testRDD))
}
test("training with external memory cache") {
sc.stop()
sc = null
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
val customSparkContext = new SparkContext(sparkConf)
val eval = new EvalError()
val trainingRDD = buildTrainingRDD(Some(customSparkContext))
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
customSparkContext.stop()
// clean up the cache files left by external-memory training
val dir = new File(".")
for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-dtrain_cache")) {
file.delete()
}
}
}

View File

@@ -201,7 +201,6 @@ void SparsePageSource::Create(DMatrix* src,
<< (bytes_write >> 20UL) << " written";
}
}
if (page->data.size() != 0) {
writer.PushWrite(std::move(page));
}