[JVM] Refactor, add filesys API
This commit is contained in:
@@ -37,6 +37,8 @@ object Test {
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val round = 2
|
||||
val model = XGBoost.train(paramMap, data, round)
|
||||
|
||||
|
||||
log.info(model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,9 @@ import org.apache.flink.api.scala.DataSet
|
||||
import org.apache.flink.api.scala._
|
||||
import org.apache.flink.ml.common.LabeledVector
|
||||
import org.apache.flink.util.Collector
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
||||
object XGBoost {
|
||||
/**
|
||||
@@ -60,6 +63,20 @@ object XGBoost {
|
||||
|
||||
val logger = LogFactory.getLog(this.getClass)
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModel(modelPath: String): XGBoostModel = {
  val path = new Path(modelPath)
  // Resolve the filesystem from the path itself so that a fully qualified URI
  // (e.g. hdfs://..., file://...) is honoured; a path without a scheme still
  // resolves to the default filesystem of the Configuration, as before.
  val input = path.getFileSystem(new Configuration).open(path)
  // The stream is opened here, so it is released here even when loading fails.
  // NOTE(review): if XGBoostScala.loadModel already closes the stream, the
  // second close is expected to be a harmless no-op - confirm.
  try {
    new XGBoostModel(XGBoostScala.loadModel(input))
  } finally {
    input.close()
  }
}
|
||||
|
||||
/**
|
||||
* Train an XGBoost model with Flink.
|
||||
*
|
||||
|
||||
@@ -16,8 +16,45 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
import org.apache.flink.api.scala.DataSet
|
||||
import org.apache.flink.api.scala._
|
||||
import org.apache.flink.ml.math.Vector
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
||||
/**
  * Wrapper around a trained [[Booster]] that adds Hadoop-filesystem persistence
  * and prediction over Flink datasets.
  *
  * @param booster The underlying booster used for saving and predicting.
  */
class XGBoostModel (booster: Booster) extends Serializable {
  /**
    * Save the model as a Hadoop filesystem file.
    *
    * The output stream is always closed before this method returns, so the
    * written bytes are flushed and the file handle is released even if the
    * booster fails while writing.
    *
    * @param modelPath The model path as in Hadoop path.
    */
  def saveModel(modelPath: String): Unit = {
    val path = new Path(modelPath)
    // Resolve the filesystem from the path so a scheme in modelPath is honoured;
    // scheme-less paths still go to the Configuration's default filesystem.
    val output = path.getFileSystem(new Configuration).create(path)
    try {
      booster.saveModel(output)
    } finally {
      output.close()
    }
  }

  /**
    * Predict given vector dataset.
    *
    * Each partition is converted into a single DMatrix and passed to the
    * booster in one call.
    *
    * @param data The dataset to be predicted.
    * @return The prediction result.
    */
  def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
    // The explicit function type selects the Iterator => TraversableOnce
    // overload of mapPartition.
    val predictMap: Iterator[Vector] => TraversableOnce[Array[Float]] =
      (it: Iterator[Vector]) => {
        // Turn every vector into a labeled point with a placeholder label of 0.0f.
        val dataIter = it.map { x =>
          val (index, value) = x.toSeq.unzip
          LabeledPoint.fromSparseVector(0.0f,
            index.toArray, value.map(_.toFloat).toArray)
        }
        // NOTE(review): dmat is never explicitly released here; confirm whether
        // DMatrix holds resources that need an explicit dispose after predict.
        val dmat = new DMatrix(dataIter, null)
        this.booster.predict(dmat)
      }
    data.mapPartition(predictMap)
  }
}
|
||||
|
||||
Reference in New Issue
Block a user