[JVM] Refactor, add filesys API

This commit is contained in:
tqchen
2016-03-06 11:33:48 -08:00
parent 457ff82e33
commit 56f7a414d1
17 changed files with 597 additions and 896 deletions

View File

@@ -37,6 +37,8 @@ object Test {
"objective" -> "binary:logistic").toMap
val round = 2
val model = XGBoost.train(paramMap, data, round)
log.info(model)
}
}

View File

@@ -25,6 +25,9 @@ import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala._
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
object XGBoost {
/**
@@ -60,6 +63,20 @@ object XGBoost {
val logger = LogFactory.getLog(this.getClass)
/**
 * Load XGBoost model from path, using Hadoop Filesystem API.
 *
 * @param modelPath The path that is accessible by hadoop filesystem API.
 * @return The loaded model
 */
def loadModel(modelPath: String): XGBoostModel = {
  // Resolve the path through the Hadoop FS abstraction so local paths,
  // HDFS URIs and any other configured scheme are all supported.
  val inputStream = FileSystem
    .get(new Configuration)
    .open(new Path(modelPath))
  try {
    new XGBoostModel(XGBoostScala.loadModel(inputStream))
  } finally {
    // Always release the stream, even if deserialization throws
    // (the original leaked the FSDataInputStream handle).
    inputStream.close()
  }
}
/**
* Train a xgboost model with flink.
*

View File

@@ -16,8 +16,45 @@
package ml.dmlc.xgboost4j.flink
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala._
import org.apache.flink.ml.math.Vector
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
class XGBoostModel(booster: Booster) extends Serializable {
  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   */
  def saveModel(modelPath: String): Unit = {
    val outputStream = FileSystem
      .get(new Configuration)
      .create(new Path(modelPath))
    try {
      booster.saveModel(outputStream)
    } finally {
      // Close even when saveModel throws; without this the output stream
      // handle leaks and the target file may be left truncated/unflushed.
      outputStream.close()
    }
  }

  /**
   * Predict given vector dataset.
   *
   * @param data The dataset to be predicted.
   * @return The prediction result.
   */
  def predict(data: DataSet[Vector]): DataSet[Array[Float]] = {
    val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
      (it: Iterator[Vector]) => {
        // Convert each Flink vector into a sparse LabeledPoint.
        // The 0.0f label is a placeholder: labels are not used at
        // prediction time.
        val toPoint = (v: Vector) => {
          val (indices, values) = v.toSeq.unzip
          LabeledPoint.fromSparseVector(
            0.0f, indices.toArray, values.map(_.toFloat).toArray)
        }
        val points = it.map(toPoint)
        // null: no cache path is supplied when building the DMatrix,
        // matching the original call — TODO confirm against DMatrix ctor.
        val dmat = new DMatrix(points, null)
        this.booster.predict(dmat)
      }
    // One DMatrix + predict call per partition keeps JNI overhead low.
    data.mapPartition(predictMap)
  }
}