[jvm-packages] xgboost4j-spark external memory (#1219)
* implement external memory support for XGBoost4J
* remove extra space
* enable external memory for prediction
* update doc
parent 587999755f
commit c85b9012c6
@@ -61,7 +61,6 @@ object DistTrainWithSpark {
 .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
 sparkConf.registerKryoClasses(Array(classOf[Booster]))
 val sc = new SparkContext(sparkConf)
-val sc = new SparkContext(sparkConf)
 val inputTrainPath = args(1)
 val outputModelPath = args(2)
 // number of iterations

@@ -73,7 +72,8 @@ object DistTrainWithSpark {
 "max_depth" -> 2,
 "objective" -> "binary:logistic").toMap
 // use 5 distributed workers to train the model
-val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5)
+// useExternalMemory indicates whether
+val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5, useExternalMemory = true)
 // save model to HDFS path
 model.saveModelToHadoop(outputModelPath)
 }
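As a quick reference for the documentation change above, here is a minimal, self-contained sketch of the updated training flow. It is a sketch only: the input/output paths, the worker count, and the round count are placeholders, and it assumes the training RDD is the MLlib LabeledPoint RDD produced by MLUtils.loadLibSVMFile, as in the project's examples; the only new piece from this PR is the useExternalMemory flag.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.util.MLUtils
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    object ExternalMemoryTrainingSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("XGBoost-spark-example"))
        // placeholder path; any libsvm-formatted dataset works the same way
        val trainRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///tmp/agaricus.txt.train")
        val paramMap = Map(
          "eta" -> 0.1f,
          "max_depth" -> 2,
          "objective" -> "binary:logistic")
        // useExternalMemory = true makes every worker back its DMatrix with an
        // on-disk cache instead of keeping the whole partition in RAM
        val model = XGBoost.train(trainRDD, paramMap, round = 10, nWorkers = 5,
          useExternalMemory = true)
        model.saveModelAsHadoopFile("hdfs:///tmp/xgboost-model") // placeholder output path
        sc.stop()
      }
    }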
@@ -32,8 +32,8 @@ public class ExternalMemory {
 //this is the only difference, add a # followed by a cache prefix name
 //several cache file with the prefix will be generated
 //currently only support convert from libsvm file
-DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
-DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
+DMatrix trainMat = new DMatrix("../demo/data/agaricus.txt.train#dtrain.cache");
+DMatrix testMat = new DMatrix("../demo/data/agaricus.txt.test#dtest.cache");
 
 //specify parameters
 HashMap<String, Object> params = new HashMap<String, Object>();
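The same cache-prefix trick works from the Scala binding; the sketch below is a hedged illustration, assuming the ml.dmlc.xgboost4j.scala.DMatrix constructor that accepts a libsvm path and a Scala XGBoost.train(dtrain, params, round, watches) overload matching the call used elsewhere in this diff.

    import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

    object ScalaExternalMemorySketch {
      def main(args: Array[String]): Unit = {
        // the "#dtrain.cache" suffix tells xgboost to stream the file through
        // on-disk cache files sharing that prefix instead of loading it into RAM
        val trainMat = new DMatrix("../demo/data/agaricus.txt.train#dtrain.cache")
        val testMat = new DMatrix("../demo/data/agaricus.txt.test#dtest.cache")
        val params = Map("eta" -> 1.0, "max_depth" -> 2, "objective" -> "binary:logistic")
        val watches = Map("train" -> trainMat, "test" -> testMat)
        val booster = XGBoost.train(trainMat, params, 2, watches)
        println(booster.predict(testMat).length)
      }
    }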
@@ -28,7 +28,7 @@ object DistTrainWithSpark {
 "usage: program num_of_rounds num_workers training_path test_path model_path")
 sys.exit(1)
 }
-val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
+val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
 .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
 sparkConf.registerKryoClasses(Array(classOf[Booster]))
 val sc = new SparkContext(sparkConf)

@@ -45,7 +45,8 @@ object DistTrainWithSpark {
 "eta" -> 0.1f,
 "max_depth" -> 2,
 "objective" -> "binary:logistic").toMap
-val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt)
+val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
+useExternalMemory = true)
 xgboostModel.predict(new DMatrix(testSet))
 // save model to HDFS path
 xgboostModel.saveModelAsHadoopFile(outputModelPath)

@@ -16,6 +16,8 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.nio.file.Paths
+
 import scala.collection.mutable
 import scala.collection.JavaConverters._
 

@@ -41,7 +43,8 @@ object XGBoost extends Serializable {
 trainingData: RDD[LabeledPoint],
 xgBoostConfMap: Map[String, Any],
 rabitEnv: mutable.Map[String, String],
-numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
+numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
+useExternalMemory: Boolean): RDD[Booster] = {
 import DataUtils._
 val partitionedData = {
 if (numWorkers > trainingData.partitions.length) {

@@ -54,11 +57,19 @@ object XGBoost extends Serializable {
 trainingData
 }
 }
+val appName = partitionedData.context.appName
 partitionedData.mapPartitions {
 trainingSamples =>
 rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
 Rabit.init(rabitEnv.asJava)
-val trainingSet = new DMatrix(new JDMatrix(trainingSamples, null))
+val cacheFileName: String = {
+if (useExternalMemory && trainingSamples.hasNext) {
+s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
+} else {
+null
+}
+}
+val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
 val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
 watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
 obj, eval)
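A short note on the naming scheme introduced above: each partition derives its own cache prefix from the Spark application name and the partition id, so concurrent tasks never collide on the same cache files. A REPL-style illustration (the application name and partition id below are hypothetical):

    // illustration only: how the per-partition prefix is composed
    val appName = "XGBoost-spark-example"   // hypothetical Spark application name
    val partitionId = 3                     // hypothetical partition id
    val cacheFileName = s"$appName-dtrain_cache-$partitionId"
    // => "XGBoost-spark-example-dtrain_cache-3"; xgboost then writes one or more
    // cache files sharing this prefix into the task's working directory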
@@ -76,12 +87,15 @@ object XGBoost extends Serializable {
 * workers equals to the partition number of trainingData RDD
 * @param obj the user-defined objective function, null by default
 * @param eval the user-defined evaluation function, null by default
+* @param useExternalMemory indicate whether to use external memory cache, by setting this flag as
+* true, the user may save the RAM cost for running XGBoost within Spark
 * @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
 * @return XGBoostModel when successful training
 */
 @throws(classOf[XGBoostError])
 def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
-nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
+nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
+useExternalMemory: Boolean = false): XGBoostModel = {
 require(nWorkers > 0, "you must specify more than 0 workers")
 val tracker = new RabitTracker(nWorkers)
 implicit val sc = trainingData.sparkContext

@@ -97,7 +111,7 @@ object XGBoost extends Serializable {
 }
 require(tracker.start(), "FAULT: Failed to start tracker")
 val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
-tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval)
+tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
 val sparkJobThread = new Thread() {
 override def run() {
 // force the job

@@ -17,7 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import org.apache.hadoop.fs.{Path, FileSystem}
-import org.apache.spark.SparkContext
+import org.apache.spark.{TaskContext, SparkContext}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}

@@ -27,13 +27,23 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser
 
 /**
 * Predict result with the given testset (represented as RDD)
+* @param testSet test set representd as RDD
+* @param useExternalCache whether to use external cache for the test set
 */
-def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
+def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
 import DataUtils._
 val broadcastBooster = testSet.sparkContext.broadcast(_booster)
+val appName = testSet.context.appName
 testSet.mapPartitions { testSamples =>
 if (testSamples.hasNext) {
-val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
+val cacheFileName = {
+if (useExternalCache) {
+s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
+} else {
+null
+}
+}
+val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
 Iterator(broadcastBooster.value.predict(dMatrix))
 } else {
 Iterator()
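On the prediction side, useExternalCache defaults to false, so existing callers are unaffected; passing true routes each test partition through an on-disk cache named after the application and partition id, mirroring the training path. A hedged usage sketch, assuming a test RDD with the same feature layout as the training data:

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD
    import ml.dmlc.xgboost4j.scala.spark.XGBoostModel

    object PredictWithExternalCacheSketch {
      // score a test RDD through the new external-cache code path
      def score(model: XGBoostModel, testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
        model.predict(testSet, useExternalCache = true)
      }
    }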
@@ -127,7 +127,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
 List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
 "objective" -> "binary:logistic").toMap,
 new scala.collection.mutable.HashMap[String, String],
-numWorkers = 2, round = 5, null, null)
+numWorkers = 2, round = 5, null, null, false)
 val boosterCount = boosterRDD.count()
 assert(boosterCount === 2)
 val boosters = boosterRDD.collect()

@@ -210,4 +210,26 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
 
 println(xgBoostModel.predict(testRDD))
 }
+
+test("training with external memory cache") {
+sc.stop()
+sc = null
+val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+val customSparkContext = new SparkContext(sparkConf)
+val eval = new EvalError()
+val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+import DataUtils._
+val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+"objective" -> "binary:logistic").toMap
+val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers, useExternalMemory = true)
+assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
+customSparkContext.stop()
+// clean
+val dir = new File(".")
+for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-dtrain_cache")) {
+file.delete()
+}
+}
 }
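The new test deletes its dtrain_cache files once the local run finishes, because the cache files are written into the working directory and are not removed automatically. Below is a hedged cleanup sketch that also covers the prediction-side prefix; the dtest_cache prefix is an assumption mirroring the naming used in XGBoostModel.predict above, not something the test itself exercises.

    import java.io.File

    object CacheCleanupSketch {
      // remove leftover external-memory cache files created in the given directory
      def cleanCacheFiles(appName: String, workDir: File = new File(".")): Unit = {
        // "<appName>-dtest_cache" is an assumed prefix, mirroring "<appName>-dtrain_cache"
        val prefixes = Seq(s"$appName-dtrain_cache", s"$appName-dtest_cache")
        for {
          file <- Option(workDir.listFiles()).getOrElse(Array.empty[File])
          if prefixes.exists(p => file.getName.startsWith(p))
        } file.delete()
      }
    }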
@@ -201,7 +201,6 @@ void SparsePageSource::Create(DMatrix* src,
 << (bytes_write >> 20UL) << " written";
 }
 }
-
 if (page->data.size() != 0) {
 writer.PushWrite(std::move(page));
 }