revise current API

This commit is contained in:
CodingCat
2016-03-07 21:48:16 -05:00
parent 9911771b02
commit fa03aaeb63
9 changed files with 170 additions and 64 deletions

View File

@@ -19,23 +19,21 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError, Rabit, RabitTracker}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
object XGBoost extends Serializable {
var boosters: RDD[Booster] = null
private val logger = LogFactory.getLog("XGBoostSpark")
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
private implicit def convertBoosterToXGBoostModel(booster: Booster)
(implicit sc: SparkContext): XGBoostModel = {
new XGBoostModel(booster)
}
@@ -57,27 +55,36 @@ object XGBoost extends Serializable {
}.cache()
}
/**
*
* @param trainingData the training set represented as an RDD
* @param configMap Map containing the configuration entries
* @param round the number of iterations
* @param obj the user-defined objective function, null by default
* @param eval the user-defined evaluation function, null by default
* @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
* @return the trained XGBoostModel if training succeeds
*/
@throws(classOf[XGBoostError])
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
val numWorkers = trainingData.partitions.length
val sc = trainingData.sparkContext
implicit val sc = trainingData.sparkContext
val tracker = new RabitTracker(numWorkers)
require(tracker.start(), "FAULT: Failed to start tracker")
val boosters = buildDistributedBoosters(trainingData, configMap,
tracker.getWorkerEnvs.asScala, numWorkers, round, obj, eval)
@volatile var booster: Booster = null
val sparkJobThread = new Thread() {
override def run() {
// force the job
boosters.foreachPartition(_ => ())
boosters.foreachPartition(() => _)
}
}
sparkJobThread.start()
val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal")
if (returnVal == 0) {
booster = boosters.first()
Some(booster).get
boosters.first()
} else {
try {
if (sparkJobThread.isAlive) {
@@ -87,21 +94,20 @@ object XGBoost extends Serializable {
case ie: InterruptedException =>
logger.info("spark job thread is interrupted")
}
null
throw new XGBoostError("XGBoostModel training failed")
}
}
/**
* Load XGBoost model from path, using Hadoop Filesystem API.
*
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
new XGBoostModel(
SXGBoost.loadModel(
FileSystem
.get(new Configuration)
.open(new Path(modelPath))))
* Load an XGBoost model from a path in an HDFS-compatible file system
*
* @param modelPath The path of the file representing the model
* @return The loaded model
*/
def loadModelFromHadoop(modelPath: String)(implicit sparkContext: SparkContext): XGBoostModel = {
val dataInStream = FileSystem.get(sparkContext.hadoopConfiguration).open(new Path(modelPath))
val xgBoostModel = new XGBoostModel(SXGBoost.loadModel(dataInStream))
dataInStream.close()
xgBoostModel
}
}

View File

@@ -16,18 +16,17 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
class XGBoostModel(booster: Booster) extends Serializable {
class XGBoostModel(booster: Booster)(implicit val sc: SparkContext) extends Serializable {
/**
* Predict result given testRDD
* @param testSet the testSet of Data vectors
* @return The predicted RDD
* Predict result with the given testset (represented as RDD)
*/
def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
import DataUtils._
@@ -39,18 +38,21 @@ class XGBoostModel(booster: Booster) extends Serializable {
}
}
/**
* predict result given the test data (represented as DMatrix)
*/
def predict(testSet: DMatrix): Array[Array[Float]] = {
booster.predict(testSet)
booster.predict(testSet, true, 0)
}
/**
* Save the model as a Hadoop filesystem file.
*
* Save the model to an HDFS-compatible file system.
*
* @param modelPath The model path, expressed as a Hadoop filesystem path.
*/
def saveModelToHadoop(modelPath: String): Unit = {
booster.saveModel(FileSystem
.get(new Configuration)
.create(new Path(modelPath)))
val outputStream = FileSystem.get(sc.hadoopConfiguration).create(new Path(modelPath))
booster.saveModel(outputStream)
outputStream.close()
}
}

View File

@@ -17,13 +17,11 @@
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import scala.collection.mutable.ListBuffer
import scala.io.Source
import scala.tools.reflect.Eval
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
import org.apache.commons.logging.LogFactory
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,10 +29,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, XGBoostError}
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
private var sc: SparkContext = null
private val numWorker = 4
private implicit var sc: SparkContext = null
private val numWorker = 2
private class EvalError extends EvalTrait {
@@ -111,14 +112,9 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
sampleList.toList
}
private def buildRDD(filePath: String): RDD[LabeledPoint] = {
val sampleList = readFile(filePath)
sc.parallelize(sampleList, numWorker)
}
private def buildTrainingRDD(): RDD[LabeledPoint] = {
val trainRDD = buildRDD(getClass.getResource("/agaricus.txt.train").getFile)
trainRDD
val sampleList = readFile(getClass.getResource("/agaricus.txt.train").getFile)
sc.parallelize(sampleList, numWorker)
}
test("build RDD containing boosters") {
@@ -140,4 +136,23 @@ class XGBoostSuite extends FunSuite with BeforeAndAfterAll {
assert(new EvalError().eval(predicts, testSetDMatrix) < 0.1)
}
}
test("save and load model") {
val eval = new EvalError()
val trainingRDD = buildTrainingRDD()
val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
import DataUtils._
val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5)
assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
xgBoostModel.saveModelToHadoop(tempFile.toFile.getAbsolutePath)
val loadedXGBooostModel = XGBoost.loadModelFromHadoop(tempFile.toFile.getAbsolutePath)
val predicts = loadedXGBooostModel.predict(testSetDMatrix)
assert(eval.eval(predicts, testSetDMatrix) < 0.1)
}
}