commit 79f9fceb6b
@@ -161,10 +161,5 @@
       <version>2.2.6</version>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>com.typesafe</groupId>
-      <artifactId>config</artifactId>
-      <version>1.2.1</version>
-    </dependency>
   </dependencies>
 </project>
@@ -25,7 +25,7 @@
   <dependencies>
     <dependency>
       <groupId>ml.dmlc</groupId>
-      <artifactId>xgboost4j</artifactId>
+      <artifactId>xgboost4j-spark</artifactId>
       <version>0.1</version>
     </dependency>
     <dependency>
@@ -0,0 +1,74 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.demo
+
+import java.io.File
+
+import scala.collection.mutable.ListBuffer
+import scala.io.Source
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.DenseVector
+import org.apache.spark.mllib.regression.LabeledPoint
+
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.spark.XGBoost
+
+
+object DistTrainWithSpark {
+
+  private def readFile(filePath: String): List[LabeledPoint] = {
+    val file = Source.fromFile(new File(filePath))
+    val sampleList = new ListBuffer[LabeledPoint]
+    for (sample <- file.getLines()) {
+      sampleList += fromSVMStringToLabeledPoint(sample)
+    }
+    sampleList.toList
+  }
+
+  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
+    val labelAndFeatures = line.split(" ")
+    val label = labelAndFeatures(0).toInt
+    val features = labelAndFeatures.tail
+    val denseFeature = new Array[Double](129)
+    for (feature <- features) {
+      val idAndValue = feature.split(":")
+      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
+    }
+    LabeledPoint(label, new DenseVector(denseFeature))
+  }
+
+  def main(args: Array[String]): Unit = {
+    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
+    if (args.length != 4) {
+      println(
+        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
+      sys.exit(1)
+    }
+    val sc = new SparkContext()
+    val inputTrainPath = args(2)
+    val inputTestPath = args(3)
+    val trainingLabeledPoints = readFile(inputTrainPath)
+    val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
+    val testLabeledPoints = readFile(inputTestPath).iterator
+    val testMatrix = new DMatrix(testLabeledPoints, null)
+    val booster = XGBoost.train(trainRDD,
+      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+        "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
+    booster.map(boosterInstance => boosterInstance.predict(testMatrix))
+  }
+}
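
Note: the demo above parses libsvm text on the driver with a hand-rolled reader and then calls sc.parallelize. A shorter route, not part of this commit, is Spark's built-in loader; a sketch, assuming 1-based libsvm indices as MLUtils expects (the parser above treats indices as 0-based into a 129-wide dense array):

    import org.apache.spark.mllib.util.MLUtils
    // Load the libsvm file as a distributed RDD[LabeledPoint] instead of
    // reading it on the driver and parallelizing afterwards.
    val trainRDD = MLUtils.loadLibSVMFile(sc, inputTrainPath)
      .repartition(args(0).toInt)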
@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
 
 import ml.dmlc.xgboost4j.LabeledPoint
 
-private[spark] object DataUtils extends Serializable {
+object DataUtils extends Serializable {
 
+  implicit def fromSparkToXGBoostLabeledPointsAsJava(
+      sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
+    fromSparkToXGBoostLabeledPoints(sps).asJava
+  }
+
   implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
-      java.util.Iterator[LabeledPoint] = {
-    (for (p <- sps) yield {
+      Iterator[LabeledPoint] = {
+    for (p <- sps) yield {
       p.features match {
         case denseFeature: DenseVector =>
           LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
           LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
             sparseFeature.values.map(_.toFloat))
       }
-    }).asJava
-  }
-
-  private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
-    (sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
-  }
-
-  private def fetchUpdateFromVector(feature: Vector) = feature match {
-    case denseFeature: DenseVector =>
-      fetchUpdateFromSparseVector(denseFeature.toSparse)
-    case sparseFeature: SparseVector =>
-      fetchUpdateFromSparseVector(sparseFeature)
-  }
+    }
   }
 }
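
Note: with DataUtils now public, the implicit conversions are the bridge that lets the demo hand a Scala Iterator of Spark LabeledPoints straight to the DMatrix constructor. A minimal sketch of that call site, with the sample data invented for illustration:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import ml.dmlc.xgboost4j.scala.DMatrix
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._

    // fromSparkToXGBoostLabeledPointsAsJava converts the Spark iterator into
    // the java.util.Iterator[ml.dmlc.xgboost4j.LabeledPoint] DMatrix expects.
    val points = Iterator(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 0.5)),
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.2)))
    val matrix = new DMatrix(points, null)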
@@ -16,10 +16,11 @@
 
 package ml.dmlc.xgboost4j.scala.spark
 
-import scala.collection.immutable.HashMap
+import scala.collection.mutable
+import scala.collection.JavaConverters._
 
-import com.typesafe.config.Config
-import org.apache.spark.{TaskContext, SparkContext}
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.TaskContext
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
@@ -28,6 +29,9 @@ import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 
 object XGBoost extends Serializable {
 
+  var boosters: RDD[Booster] = null
+  private val logger = LogFactory.getLog("XGBoostSpark")
+
   implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
     new XGBoostModel(booster)
   }
@@ -35,45 +39,51 @@ object XGBoost extends Serializable {
   private[spark] def buildDistributedBoosters(
       trainingData: RDD[LabeledPoint],
       xgBoostConfMap: Map[String, AnyRef],
+      rabitEnv: mutable.Map[String, String],
       numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
     import DataUtils._
-    val sc = trainingData.sparkContext
-    val tracker = new RabitTracker(numWorkers)
-    if (tracker.start()) {
-      trainingData.repartition(numWorkers).mapPartitions {
-        trainingSamples =>
-          Rabit.init(new java.util.HashMap[String, String]() {
-            put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
-          })
-          val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
-          val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
-            watches = new HashMap[String, DMatrix], obj, eval)
-          Rabit.shutdown()
-          Iterator(booster)
-      }.cache()
-    } else {
-      null
-    }
+    trainingData.repartition(numWorkers).mapPartitions {
+      trainingSamples =>
+        rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
+        Rabit.init(rabitEnv.asJava)
+        val dMatrix = new DMatrix(new JDMatrix(trainingSamples, null))
+        val booster = SXGBoost.train(xgBoostConfMap, dMatrix, round,
+          watches = new mutable.HashMap[String, DMatrix]{put("train", dMatrix)}.toMap, obj, eval)
+        Rabit.shutdown()
+        Iterator(booster)
+    }.cache()
   }
 
-  def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
-      eval: EvalTrait = null): Option[XGBoostModel] = {
-    import DataUtils._
-    val numWorkers = config.getInt("numWorkers")
-    val round = config.getInt("round")
+  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, AnyRef], round: Int,
+      obj: ObjectiveTrait = null, eval: EvalTrait = null): Option[XGBoostModel] = {
+    val numWorkers = trainingData.partitions.length
     val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
-    if (tracker.start()) {
-      // TODO: build configuration map from config
-      val xgBoostConfigMap = new HashMap[String, AnyRef]()
-      val boosters = buildDistributedBoosters(trainingData, xgBoostConfigMap, numWorkers, round,
-        obj, eval)
-      // force the job
-      sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
-      tracker.waitFor()
-      // TODO: how to choose best model
-      Some(boosters.first())
+    require(tracker.start(), "FAULT: Failed to start tracker")
+    boosters = buildDistributedBoosters(trainingData, configMap,
+      tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
+    @volatile var booster: Booster = null
+    val sparkJobThread = new Thread() {
+      override def run() {
+        // force the job
+        boosters.foreachPartition(_ => ())
+      }
+    }
+    sparkJobThread.start()
+    val returnVal = tracker.waitFor()
+    logger.info(s"Rabit returns with exit code $returnVal")
+    if (returnVal == 0) {
+      booster = boosters.first()
+      Some(booster)
     } else {
+      try {
+        if (sparkJobThread.isAlive) {
+          sparkJobThread.interrupt()
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logger.info("spark job thread is interrupted")
+      }
       None
     }
   }
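
Note: train no longer reads numWorkers and round from a typesafe Config; the worker count follows the RDD's partitioning, the Rabit worker environment comes from tracker.getWorkerEnvs, and the Spark action runs on a helper thread so the driver can block on tracker.waitFor() (a nonzero exit code yields None). A minimal caller sketch mirroring the demo and the test suite:

    import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}

    // trainRDD: RDD[org.apache.spark.mllib.regression.LabeledPoint], prepared upstream;
    // parameter keys follow the ones used elsewhere in this diff.
    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
    val model: Option[XGBoostModel] = XGBoost.train(trainRDD, paramMap, round = 10)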
@@ -130,6 +130,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
       trainingRDD,
       List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
         "objective" -> "binary:logistic").toMap,
+      new scala.collection.mutable.HashMap[String, String],
       numWorker, 2, null, null)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === numWorker)
@@ -1,10 +1,12 @@
 package ml.dmlc.xgboost4j;
 
+import java.io.Serializable;
+
 /**
  * Labeled data point for training examples.
  * Represent a sparse training instance.
  */
-public class LabeledPoint {
+public class LabeledPoint implements Serializable {
   /** Label of the point */
   public float label;
   /** Weight of this data point */
@@ -24,7 +24,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 /**
- * Booster for xgboost, this is a model API that support interactive build of a XGBOost Model
+ * Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
  */
 public class Booster implements Serializable {
   private static final Log logger = LogFactory.getLog(Booster.class);
@@ -353,10 +353,26 @@ public class Booster implements Serializable {
    * Save the model as byte array representation.
    * Write these bytes to a file will give compatible format with other xgboost bindings.
    *
-   * If java natively support HDFS file API, use toByteArray and write the ByteArray,
+   * If java natively support HDFS file API, use toByteArray and write the ByteArray
+   *
+   * @param withStats Controls whether the split statistics are output.
+   * @return dumped model information
+   * @throws XGBoostError native error
+   */
+  private String[] getDumpInfo(boolean withStats) throws XGBoostError {
+    int statsFlag = 0;
+    if (withStats) {
+      statsFlag = 1;
+    }
+    String[][] modelInfos = new String[1][];
+    JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
+    return modelInfos[0];
+  }
+
+  /**
    *
    * @return the saved byte array.
-   * @throws XGBoostError
+   * @throws XGBoostError native error
    */
   public byte[] toByteArray() throws XGBoostError {
     byte[][] bytes = new byte[1][];
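
Note: the javadoc above recommends toByteArray for stores without a native binding. A sketch of that route from Scala, assuming Hadoop's FileSystem API is on the classpath; the output path is invented for illustration:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}

    // booster: ml.dmlc.xgboost4j.java.Booster obtained from training
    val bytes = booster.toByteArray
    val fs = FileSystem.get(new Configuration())
    val out = fs.create(new Path("hdfs:///models/xgboost.model"))
    try out.write(bytes) finally out.close()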
@@ -1,5 +1,6 @@
 package ml.dmlc.xgboost4j.java;
 
+import java.io.Serializable;
 import java.util.Iterator;
 
 import ml.dmlc.xgboost4j.LabeledPoint;
@@ -1,6 +1,7 @@
 package ml.dmlc.xgboost4j.java;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.Map;
 
 import org.apache.commons.logging.Log;
@@ -47,8 +47,15 @@ public class RabitTracker {
         while ((line = reader.readLine()) != null) {
           trackerProcessLogger.info(line);
         }
+        trackerProcess.get().waitFor();
+        trackerProcessLogger.info("Tracker Process ends with exit code " +
+          trackerProcess.get().exitValue());
       } catch (IOException ex) {
         trackerProcessLogger.error(ex.toString());
+      } catch (InterruptedException ie) {
+        // we should not get here as RabitTracker is accessed in the main thread
+        ie.printStackTrace();
+        logger.error("the RabitTracker thread is terminated unexpectedly");
       }
     }
   }
@@ -134,15 +141,18 @@ public class RabitTracker {
     }
   }
 
-  public void waitFor() {
+  public int waitFor() {
     try {
       trackerProcess.get().waitFor();
-      logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
+      int returnVal = trackerProcess.get().exitValue();
+      logger.info("Tracker Process ends with exit code " + returnVal);
       stop();
+      return returnVal;
     } catch (InterruptedException e) {
       // we should not get here as RabitTracker is accessed in the main thread
       e.printStackTrace();
       logger.error("the RabitTracker thread is terminated unexpectedly");
+      return 1;
     }
   }
 }