commit 79f9fceb6b
@@ -161,10 +161,5 @@
       <version>2.2.6</version>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>com.typesafe</groupId>
-      <artifactId>config</artifactId>
-      <version>1.2.1</version>
-    </dependency>
   </dependencies>
 </project>

@@ -25,7 +25,7 @@
   <dependencies>
     <dependency>
       <groupId>ml.dmlc</groupId>
-      <artifactId>xgboost4j</artifactId>
+      <artifactId>xgboost4j-spark</artifactId>
       <version>0.1</version>
     </dependency>
     <dependency>

@@ -0,0 +1,74 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.demo
+
+import java.io.File
+
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.DenseVector
+import org.apache.spark.mllib.regression.LabeledPoint
+
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.spark.XGBoost
+
+
+object DistTrainWithSpark {
+
+  private def readFile(filePath: String): List[LabeledPoint] = {
+    val file = Source.fromFile(new File(filePath))
+    val sampleList = new ListBuffer[LabeledPoint]
+    for (sample <- file.getLines()) {
+      sampleList += fromSVMStringToLabeledPoint(sample)
+    }
+    sampleList.toList
+  }
+
+  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
+    val labelAndFeatures = line.split(" ")
+    val label = labelAndFeatures(0).toInt
+    val features = labelAndFeatures.tail
+    val denseFeature = new Array[Double](129)
+    for (feature <- features) {
+      val idAndValue = feature.split(":")
+      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+    }
+    LabeledPoint(label, new DenseVector(denseFeature))
+  }
+
+  def main(args: Array[String]): Unit = {
+    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
+    if (args.length != 4) {
+      println(
+        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
+      sys.exit(1)
+    }
+    val sc = new SparkContext()
+    val inputTrainPath = args(2)
+    val inputTestPath = args(3)
+    val trainingLabeledPoints = readFile(inputTrainPath)
+    val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
+    val testLabeledPoints = readFile(inputTestPath).iterator
+    val testMatrix = new DMatrix(testLabeledPoints, null)
+    val booster = XGBoost.train(trainRDD,
+      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+        "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
+    booster.map(boosterInstance => boosterInstance.predict(testMatrix))
+  }
+}

@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
 
 import ml.dmlc.xgboost4j.LabeledPoint
 
-private[spark] object DataUtils extends Serializable {
+object DataUtils extends Serializable {
 
+  implicit def fromSparkToXGBoostLabeledPointsAsJava(
+      sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
+    fromSparkToXGBoostLabeledPoints(sps).asJava
+  }
+
   implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
-      java.util.Iterator[LabeledPoint] = {
-    (for (p <- sps) yield {
+      Iterator[LabeledPoint] = {
+    for (p <- sps) yield {
       p.features match {
         case denseFeature: DenseVector =>
           LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))

@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
           LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
             sparseFeature.values.map(_.toFloat))
       }
-    }).asJava
   }
-
-  private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
-    (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
-  }
-
-  private def fetchUpdateFromVector(feature: Vector) = feature match {
-    case denseFeature: DenseVector =>
-      fetchUpdateFromSparseVector(denseFeature.toSparse)
-    case sparseFeature: SparseVector =>
-      fetchUpdateFromSparseVector(sparseFeature)
-  }
+  }
 }

@@ -16,10 +16,11 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.immutable.HashMap
+import scala.collection.mutable
+import scala.collection.JavaConverters._
 
-import com.typesafe.config.Config
-import org.apache.spark.{TaskContext, SparkContext}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.TaskContext
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
@@ -28,6 +29,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 
 object XGBoost extends Serializable {
 
+  var boosters: RDD[Booster] = null
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
   implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
     new XGBoostModel(booster)
   }

@@ -35,45 +39,51 @@ object XGBoost extends Serializable {
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, AnyRef],
+      rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
-    val sc = trainingData.sparkContext
-    val tracker = new RabitTracker(numWorkers)
-    if (tracker.start()) {
       trainingData.repartition(numWorkers).mapPartitions {
         trainingSamples =>
-          Rabit.init(new java.util.HashMap[String, String]() {
-            put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
-          })
+          rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+          Rabit.init(rabitEnv.asJava)
          val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
          val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
-            watches = new HashMap[String, DMatrix], obj, eval)
+            watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
          Rabit.shutdown()
          Iterator(booster)
       }.cache()
-    } else {
-      null
-    }
   }
 
-  def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
-      eval: EvalTrait = null): Option[XGBoostModel] = {
-    import DataUtils._
-    val numWorkers = config.getInt("numWorkers")
-    val round = config.getInt("round")
+  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
+      obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+    val numWorkers = trainingData.partitions.length
     val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
-    if (tracker.start()) {
-      // TODO: build configuration map from config
-      val xgBoostConfigMap = new HashMap[String, AnyRef]()
-      val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
-        obj, eval)
+    require(tracker.start(), "FAULT: Failed to start tracker")
+    boosters = buildDistributedBoosters(trainingData, configMap,
+      tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
+    @volatile var booster: Booster = null
+    val sparkJobThread = new Thread() {
+      override def run() {
        // force the job
-        sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
-        tracker.waitFor()
-        // TODO: how to choose best model
-        Some(boosters.first())
+        boosters.foreachPartition(_ => ())
+      }
+    }
+    sparkJobThread.start()
+    val returnVal = tracker.waitFor()
+    logger.info(s"Rabit returns with exit code $returnVal")
+    if (returnVal == 0) {
+      booster = boosters.first()
+      Some(booster)
     } else {
+      try {
+        if (sparkJobThread.isAlive) {
+          sparkJobThread.interrupt()
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logger.info("spark job thread is interrupted")
+      }
       None
     }
   }

@@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
       trainingRDD,
       List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
+      new scala.collection.mutable.HashMap[String, String],
       numWorker, 2, null, null)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === numWorker)

@@ -1,10 +1,12 @@
 package ml.dmlc.xgboost4j;
 
+import java.io.Serializable;
+
 /**
  * Labeled data point for training examples.
  * Represent a sparse training instance.
  */
-public class LabeledPoint {
+public class LabeledPoint implements Serializable {
   /** Label of the point */
   public float label;
   /** Weight of this data point */

@@ -24,7 +24,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 /**
- * Booster for xgboost, this is a model API that support interactive build of a XGBOost Model
+ * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
  */
 public class Booster implements Serializable {
   private static final Log logger = LogFactory.getLog(Booster.class);

@@ -353,10 +353,26 @@
    * Save the model as byte array representation.
    * Write these bytes to a file will give compatible format with other xgboost bindings.
    *
-   * If java natively support HDFS file API, use toByteArray and write the ByteArray,
+   * If java natively support HDFS file API, use toByteArray and write the ByteArray
    *
+   * @param withStats Controls whether the split statistics are output.
+   * @return dumped model information
+   * @throws XGBoostError native error
+   */
+  private String[] getDumpInfo(boolean withStats) throws XGBoostError {
+    int statsFlag = 0;
+    if (withStats) {
+      statsFlag = 1;
+    }
+    String[][] modelInfos = new String[1][];
+    JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
+    return modelInfos[0];
+  }
+
+  /**
+   *
    * @return the saved byte array.
-   * @throws XGBoostError
+   * @throws XGBoostError native error
    */
   public byte[] toByteArray() throws XGBoostError {
     byte[][] bytes = new byte[1][];

@@ -1,5 +1,6 @@
 package ml.dmlc.xgboost4j.java;
 
+import java.io.Serializable;
 import java.util.Iterator;
 
 import ml.dmlc.xgboost4j.LabeledPoint;

@@ -1,6 +1,7 @@
 package ml.dmlc.xgboost4j.java;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.Map;
 
 import org.apache.commons.logging.Log;

@@ -47,8 +47,15 @@ public class RabitTracker {
         while ((line = reader.readLine()) != null) {
           trackerProcessLogger.info(line);
         }
+        trackerProcess.get().waitFor();
+        trackerProcessLogger.info("Tracker Process ends with exit code " +
+                trackerProcess.get().exitValue());
       } catch (IOException ex) {
         trackerProcessLogger.error(ex.toString());
+      } catch (InterruptedException ie) {
+        // we should not get here as RabitTracker is accessed in the main thread
+        ie.printStackTrace();
+        logger.error("the RabitTracker thread is terminated unexpectedly");
       }
     }
   }

@@ -134,15 +141,18 @@ public class RabitTracker {
     }
   }
 
-  public void waitFor() {
+  public int waitFor() {
     try {
       trackerProcess.get().waitFor();
-      logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
+      int returnVal = trackerProcess.get().exitValue();
+      logger.info("Tracker Process ends with exit code " + returnVal);
       stop();
+      return returnVal;
     } catch (InterruptedException e) {
       // we should not get here as RabitTracker is accessed in the main thread
       e.printStackTrace();
       logger.error("the RabitTracker thread is terminated unexpectedly");
+      return 1;
     }
   }
 }
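For reference, a minimal sketch (not part of the commit) of driving the refactored XGBoost.train(trainingData, configMap, round, obj, eval) entry point from a Spark driver; it mirrors the DistTrainWithSpark demo added above. The input path, parameter values, object name, and the use of MLUtils.loadLibSVMFile as a loader are illustrative assumptions; only the train signature and the partitions-to-workers mapping come from the diff.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils

import ml.dmlc.xgboost4j.scala.spark.XGBoost

object TrainSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext()
    // Placeholder path: any LibSVM-formatted training file reachable from the cluster.
    // The number of partitions determines the number of Rabit workers in this API.
    val trainRDD = MLUtils.loadLibSVMFile(sc, "hdfs:///tmp/agaricus.txt.train").repartition(2)
    val params = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
    // Per the refactored API, Some(model) is returned only if the Rabit tracker exits with code 0.
    val model = XGBoost.train(trainRDD, params, 2, null, null)
    model.foreach(m => println("distributed training finished: " + m))
    sc.stop()
  }
}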