revise current API

commit fa03aaeb63
parent 9911771b02
Author: CodingCat
Date:   2016-03-07 21:48:16 -05:00

9 changed files with 170 additions and 64 deletions

File: (git submodule pointer)

@@ -1 +1 @@
-Subproject commit 969fb6455ae41d5d2f7c4ba8921f4885e9aa63c8
+Subproject commit 4e6459b0bc15e6cf9b315cc75e2e5495c03cd417

File: DistTrainWithSpark.scala

@@ -38,8 +38,8 @@ object DistTrainWithSpark {
       "eta" -> 0.1f,
       "max_depth" -> 2,
       "objective" -> "binary:logistic").toMap
-    val model = XGBoost.train(trainRDD, paramMap, numRound)
+    val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound)
     // save model to HDFS path
-    model.saveModelToHadoop(outputModelPath)
+    xgboostModel.saveModelToHadoop(outputModelPath)
   }
 }

File: XGBoost.scala (ml.dmlc.xgboost4j.scala.spark)

@@ -19,23 +19,21 @@ package ml.dmlc.xgboost4j.scala.spark
 import scala.collection.mutable
 import scala.collection.JavaConverters._

-import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.commons.logging.LogFactory
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker}
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}

 object XGBoost extends Serializable {
-  var boosters: RDD[Booster] = null
   private val logger = LogFactory.getLog("XGBoostSpark")

-  implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
+  private implicit def convertBoosterToXGBoostModel(booster: Booster)
+      (implicit sc: SparkContext): XGBoostModel = {
     new XGBoostModel(booster)
   }
@@ -57,27 +55,36 @@ object XGBoost extends Serializable {
     }.cache()
   }

+  /**
+   * Train an XGBoost model with the given training data.
+   *
+   * @param trainingData the training set represented as an RDD
+   * @param configMap Map containing the configuration entries
+   * @param round the number of iterations
+   * @param obj the user-defined objective function, null by default
+   * @param eval the user-defined evaluation function, null by default
+   * @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
+   * @return the trained XGBoostModel
+   */
+  @throws(classOf[XGBoostError])
   def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
       obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
     val numWorkers = trainingData.partitions.length
-    val sc = trainingData.sparkContext
+    implicit val sc = trainingData.sparkContext
     val tracker = new RabitTracker(numWorkers)
     require(tracker.start(), "FAULT: Failed to start tracker")
     val boosters = buildDistributedBoosters(trainingData, configMap,
       tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
-    @volatile var booster: Booster = null
     val sparkJobThread = new Thread() {
       override def run() {
         // force the job
         boosters.foreachPartition(_ => ())
       }
     }
     sparkJobThread.start()
     val returnVal = tracker.waitFor()
     logger.info(s"Rabit returns with exit code $returnVal")
     if (returnVal == 0) {
-      booster = boosters.first()
-      Some(booster).get
+      boosters.first()
     } else {
       try {
         if (sparkJobThread.isAlive) {
@@ -87,21 +94,20 @@ object XGBoost extends Serializable {
         case ie: InterruptedException =>
           logger.info("spark job thread is interrupted")
       }
-      null
+      throw new XGBoostError("XGBoostModel training failed")
     }
   }

   /**
-   * Load XGBoost model from path, using Hadoop Filesystem API.
+   * Load an XGBoost model from a path on an HDFS-compatible file system.
    *
-   * @param modelPath The path that is accessible by hadoop filesystem API.
+   * @param modelPath The path of the file representing the model
    * @return The loaded model
    */
-  def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
-    new XGBoostModel(
-      SXGBoost.loadModel(
-        FileSystem
-          .get(new Configuration)
-          .open(new Path(modelPath))))
+  def loadModelFromHadoop(modelPath: String)(implicit sparkContext: SparkContext): XGBoostModel = {
+    val dataInStream = FileSystem.get(sparkContext.hadoopConfiguration).open(new Path(modelPath))
+    val xgBoostModel = new XGBoostModel(SXGBoost.loadModel(dataInStream))
+    dataInStream.close()
+    xgBoostModel
   }
 }
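Taken together, the revised entry points thread a SparkContext implicitly: train() binds it from the input RDD, while loadModelFromHadoop() expects one in implicit scope at the call site. A minimal usage sketch of the revised API (the HDFS path, parameter values, and round count below are placeholders, not part of this commit):

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD
    import ml.dmlc.xgboost4j.scala.spark.XGBoost

    def trainSaveAndReload(trainRDD: RDD[LabeledPoint]): Unit = {
      // train() derives the worker count from the RDD's partition count and
      // uses the RDD's SparkContext for the implicit Booster-to-model conversion
      val paramMap = Map("eta" -> "1", "max_depth" -> "2",
        "objective" -> "binary:logistic")
      val model = XGBoost.train(trainRDD, paramMap, round = 5)
      model.saveModelToHadoop("hdfs:///tmp/xgboost.model")  // placeholder path
      // loading requires an implicit SparkContext at the call site
      implicit val sc: SparkContext = trainRDD.sparkContext
      val reloaded = XGBoost.loadModelFromHadoop("hdfs:///tmp/xgboost.model")
    }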

File: XGBoostModel.scala (ml.dmlc.xgboost4j.scala.spark)

@@ -16,18 +16,17 @@
 package ml.dmlc.xgboost4j.scala.spark

-import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
 import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}

-class XGBoostModel(booster: Booster) extends Serializable {
+class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {

   /**
-   * Predict result given testRDD
+   * Predict results with the given test set (represented as an RDD)
+   *
+   * @param testSet the test set of data vectors
+   * @return the prediction results as an RDD
    */
   def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
     import DataUtils._
@@ -39,18 +38,21 @@ class XGBoostModel(booster: Booster) extends Serializable {
     }
   }

+  /**
+   * Predict results for the given test data (represented as a DMatrix)
+   */
   def predict(testSet: DMatrix): Array[Array[Float]] = {
-    booster.predict(testSet)
+    booster.predict(testSet, true, 0)
   }

   /**
-   * Save the model as a Hadoop filesystem file.
+   * Save the model to an HDFS-compatible file system.
    *
    * @param modelPath The model path as in Hadoop path.
    */
   def saveModelToHadoop(modelPath: String): Unit = {
-    booster.saveModel(FileSystem
-      .get(new Configuration)
-      .create(new Path(modelPath)))
+    val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
+    booster.saveModel(outputStream)
+    outputStream.close()
   }
 }
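XGBoostModel now offers two prediction paths, and the DMatrix overload delegates to booster.predict(testSet, true, 0), i.e. it returns raw margin scores rather than transformed probabilities. A short sketch of both paths, assuming a trained model and test data are already at hand (all names below are placeholders):

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD
    import ml.dmlc.xgboost4j.scala.DMatrix
    import ml.dmlc.xgboost4j.scala.spark.XGBoostModel

    def predictBothWays(model: XGBoostModel, vectors: RDD[Vector],
        localTestSet: DMatrix): Unit = {
      // distributed prediction: one result block per partition of the input RDD
      val distributed: RDD[Array[Array[Float]]] = model.predict(vectors)
      // local prediction: raw margins, because the overload passes outputMargin = true
      val margins: Array[Array[Float]] = model.predict(localTestSet)
    }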

File: XGBoostSuite.scala (ml.dmlc.xgboost4j.scala.spark)

@@ -17,13 +17,11 @@
 package ml.dmlc.xgboost4j.scala.spark

 import java.io.File
+import java.nio.file.Files

 import scala.collection.mutable.ListBuffer
 import scala.io.Source
-import scala.tools.reflect.Eval

-import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
-import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.mllib.linalg.DenseVector
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,10 +29,13 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}
 import org.scalatest.{BeforeAndAfterAll, FunSuite}

+import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
+
 class XGBoostSuite extends FunSuite with BeforeAndAfterAll {

-  private var sc: SparkContext = null
-  private val numWorker = 4
+  private implicit var sc: SparkContext = null
+  private val numWorker = 2

   private class EvalError extends EvalTrait {
@@ -111,14 +112,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
     sampleList.toList
   }

-  private def buildRDD(filePath: String): RDD[LabeledPoint] = {
-    val sampleList = readFile(filePath)
-    sc.parallelize(sampleList, numWorker)
-  }
-
   private def buildTrainingRDD(): RDD[LabeledPoint] = {
-    val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile)
-    trainRDD
+    val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
+    sc.parallelize(sampleList, numWorker)
   }

   test("build RDD containing boosters") {
@@ -140,4 +136,23 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
       assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
     }
   }
+
+  test("save and load model") {
+    val eval = new EvalError()
+    val trainingRDD = buildTrainingRDD()
+    val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+    import DataUtils._
+    val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+    val tempDir = Files.createTempDirectory("xgboosttest-")
+    val tempFile = Files.createTempFile(tempDir, "", "")
+    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+      "objective" -> "binary:logistic").toMap
+    val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5)
+    assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
+    xgBoostModel.saveModelToHadoop(tempFile.toFile.getAbsolutePath)
+    val loadedXGBoostModel = XGBoost.loadModelFromHadoop(tempFile.toFile.getAbsolutePath)
+    val predicts = loadedXGBoostModel.predict(testSetDMatrix)
+    assert(eval.eval(predicts, testSetDMatrix) < 0.1)
+  }
 }

File: RabitTracker.java (ml.dmlc.xgboost4j.java)

@@ -129,6 +129,7 @@ public class RabitTracker {
   public boolean start() {
     if (startTrackerProcess()) {
       logger.debug("Tracker started, with env=" + envs.toString());
+      System.out.println("Tracker started, with env=" + envs.toString());
       // also start a tracker logger
       Thread logger_thread = new Thread(new TrackerProcessLogger());
       logger_thread.setDaemon(true);

File: Booster.scala (ml.dmlc.xgboost4j.scala)

@@ -119,7 +119,7 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
    */
   @throws(classOf[XGBoostError])
   def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
       : Array[Array[Float]] = {
     booster.predict(data.jDMatrix, outPutMargin, treeLimit)
   }

@@ -178,6 +178,10 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
     booster.getFeatureScore(featureMap).asScala
   }

+  def toByteArray: Array[Byte] = {
+    booster.toByteArray
+  }
+
   /**
    * Dispose the booster when it is no longer needed
    */
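toByteArray exposes the booster's serialized form, which is what the new save/load tests use to assert that a persisted model round-trips without loss. A minimal sketch of that check, assuming a booster trained elsewhere and a writable path (both placeholders):

    import ml.dmlc.xgboost4j.scala.{Booster, XGBoost}

    // returns true when the reloaded model is byte-for-byte identical
    def roundTripsLosslessly(booster: Booster, path: String): Boolean = {
      booster.saveModel(path)
      val reloaded = XGBoost.loadModel(path)
      java.util.Arrays.equals(reloaded.toByteArray, booster.toByteArray)
    }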

File: BoosterImplTest.java (ml.dmlc.xgboost4j.java)

@@ -17,7 +17,10 @@ package ml.dmlc.xgboost4j.java;

 import java.io.File;
 import java.io.FileInputStream;
+import java.io.FileOutputStream;
 import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
@@ -70,11 +73,7 @@ public class BoosterImplTest {
     }
   }

-  @Test
-  public void testBoosterBasic() throws XGBoostError, IOException {
-    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
-    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+  private Booster trainBooster(DMatrix trainMat, DMatrix testMat) throws XGBoostError {
     //set params
     Map<String, Object> paramMap = new HashMap<String, Object>() {
       {
@@ -92,10 +91,19 @@
     watches.put("test", testMat);
     //set round
-    int round = 2;
+    int round = 5;
     //train a boost model
-    Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
+    return XGBoost.train(paramMap, trainMat, round, watches, null, null);
+  }
+
+  @Test
+  public void testBoosterBasic() throws XGBoostError, IOException {
+    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+    Booster booster = trainBooster(trainMat, testMat);

     //predict raw output
     float[][] predicts = booster.predict(testMat, true, 0);
@@ -104,14 +112,43 @@
     IEvaluation eval = new EvalError();
     //error must be less than 0.1
     TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
+  }
+
+  @Test
+  public void saveLoadModelWithPath() throws XGBoostError, IOException {
+    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+    IEvaluation eval = new EvalError();
+    Booster booster = trainBooster(trainMat, testMat);

     // save and load
     File temp = File.createTempFile("temp", "model");
     temp.deleteOnExit();
     booster.saveModel(temp.getAbsolutePath());

-    Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath()));
+    Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
     assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
+    float[][] predicts2 = bst2.predict(testMat, true, 0);
+    TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
+  }
+
+  @Test
+  public void saveLoadModelWithStream() throws XGBoostError, IOException {
+    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+    Booster booster = trainBooster(trainMat, testMat);
+
+    Path tempDir = Files.createTempDirectory("boosterTest-");
+    File tempFile = Files.createTempFile(tempDir, "", "").toFile();
+    booster.saveModel(new FileOutputStream(tempFile));
+    IEvaluation eval = new EvalError();
+    Booster loadedBooster = XGBoost.loadModel(new FileInputStream(tempFile));
+    float originalPredictError = eval.eval(booster.predict(testMat, true), testMat);
+    TestCase.assertTrue("originalPredictErr:" + originalPredictError,
+      originalPredictError < 0.1f);
+    float loadedPredictError = eval.eval(loadedBooster.predict(testMat, true), testMat);
+    TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
   }

   /**

File: ScalaBoosterImplSuite.scala (ml.dmlc.xgboost4j.scala)

@@ -16,10 +16,14 @@
 package ml.dmlc.xgboost4j.scala

-import ml.dmlc.xgboost4j.java.XGBoostError
+import java.io.{FileOutputStream, FileInputStream, File}
+
+import junit.framework.TestCase
 import org.apache.commons.logging.LogFactory
 import org.scalatest.FunSuite

+import ml.dmlc.xgboost4j.java.XGBoostError
+
 class ScalaBoosterImplSuite extends FunSuite {

   private class EvalError extends EvalTrait {
@@ -64,21 +68,58 @@ class ScalaBoosterImplSuite extends FunSuite {
     }
   }

-  test("basic operation of booster") {
-    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
-    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+  private def trainBooster(trainMat: DMatrix, testMat: DMatrix): Booster = {
     val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
       "objective" -> "binary:logistic").toMap
     val watches = List("train" -> trainMat, "test" -> testMat).toMap
     val round = 2
-    val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
+    XGBoost.train(paramMap, trainMat, round, watches, null, null)
+  }
+
+  test("basic operation of booster") {
+    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+    val booster = trainBooster(trainMat, testMat)
     val predicts = booster.predict(testMat, true)
     val eval = new EvalError
     assert(eval.eval(predicts, testMat) < 0.1)
   }
+
+  test("save/load model with path") {
+    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+    val eval = new EvalError
+    val booster = trainBooster(trainMat, testMat)
+
+    // save and load
+    val temp: File = File.createTempFile("temp", "model")
+    temp.deleteOnExit()
+    booster.saveModel(temp.getAbsolutePath)
+
+    val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath)
+    assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
+    val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
+    TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
+  }
+
+  test("save/load model with stream") {
+    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+    val eval = new EvalError
+    val booster = trainBooster(trainMat, testMat)
+
+    // save and load
+    val temp: File = File.createTempFile("temp", "model")
+    temp.deleteOnExit()
+    booster.saveModel(new FileOutputStream(temp.getAbsolutePath))
+    val bst2: Booster = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath))
+    assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
+    val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
+    TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
+  }

   test("cross validation") {
     val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
     val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",