[JVM] Refactor, add filesys API
This commit is contained in:
@@ -18,20 +18,23 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster}
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
trait Booster extends Serializable {
|
||||
|
||||
class Booster private[xgboost4j](booster: JBooster) extends Serializable {
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
*
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
*/
|
||||
* Set parameter to the Booster.
|
||||
*
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setParam(key: String, value: String)
|
||||
def setParam(key: String, value: AnyRef): Unit = {
|
||||
booster.setParam(key, value)
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
@@ -39,7 +42,9 @@ trait Booster extends Serializable {
|
||||
* @param params parameters key-value map
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setParams(params: Map[String, AnyRef])
|
||||
def setParams(params: Map[String, AnyRef]): Unit = {
|
||||
booster.setParams(params.asJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update (one iteration)
|
||||
@@ -48,7 +53,9 @@ trait Booster extends Serializable {
|
||||
* @param iter current iteration number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, iter: Int)
|
||||
def update(dtrain: DMatrix, iter: Int): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
/**
|
||||
* update with a customized objective function
|
||||
@@ -57,7 +64,9 @@ trait Booster extends Serializable {
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait)
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
/**
|
||||
* update with given grad and hess
|
||||
@@ -67,7 +76,9 @@ trait Booster extends Serializable {
|
||||
* @param hess second order of gradient
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with the given DMatrix instances.
|
||||
@@ -78,7 +89,10 @@ trait Booster extends Serializable {
|
||||
* @return eval information
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int)
|
||||
: String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
@@ -89,26 +103,11 @@ trait Booster extends Serializable {
|
||||
* @return eval information
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait)
|
||||
: String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
@@ -119,22 +118,24 @@ trait Booster extends Serializable {
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
|
||||
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
|
||||
: Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
* Predict the leaf indices
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
|
||||
def predictLeaf(data: DMatrix, treeLimit: Int = 0)
|
||||
: Array[Array[Float]] = {
|
||||
booster.predictLeaf(data.jDMatrix, treeLimit)
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
@@ -142,46 +143,50 @@ trait Booster extends Serializable {
|
||||
* @param modelPath model path
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def saveModel(modelPath: String)
|
||||
|
||||
def saveModel(modelPath: String): Unit = {
|
||||
booster.saveModel(modelPath)
|
||||
}
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool Controls whether the split statistics are output.
|
||||
*/
|
||||
@throws(classOf[IOException])
|
||||
* save model to Output stream
|
||||
*
|
||||
* @param out Output stream
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def dumpModel(modelPath: String, withStats: Boolean)
|
||||
|
||||
def saveModel(out: java.io.OutputStream): Unit = {
|
||||
booster.saveModel(out)
|
||||
}
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
* Dump model as Array of string
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
*/
|
||||
@throws(classOf[IOException])
|
||||
@throws(classOf[XGBoostError])
|
||||
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
|
||||
def getModelDump(featureMap: String = null, withStats: Boolean = false)
|
||||
: Array[String] = {
|
||||
booster.getModelDump(featureMap, withStats)
|
||||
}
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
* Get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore: mutable.Map[String, Integer]
|
||||
def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore(featureMap).asScala
|
||||
}
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
|
||||
* Dispose the booster when it is no longer needed
|
||||
*/
|
||||
def dispose: Unit = {
|
||||
booster.dispose()
|
||||
}
|
||||
|
||||
def dispose
|
||||
override def finalize(): Unit = {
|
||||
super.finalize()
|
||||
dispose
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {
|
||||
|
||||
override def setParam(key: String, value: String): Unit = {
|
||||
booster.setParam(key, value)
|
||||
}
|
||||
|
||||
override def update(dtrain: DMatrix, iter: Int): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
|
||||
booster.dumpModel(modelPath, withStats)
|
||||
}
|
||||
|
||||
override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
|
||||
booster.dumpModel(modelPath, featureMap, withStats)
|
||||
}
|
||||
|
||||
override def setParams(params: Map[String, AnyRef]): Unit = {
|
||||
booster.setParams(params.asJava)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
|
||||
String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||
}
|
||||
|
||||
override def dispose: Unit = {
|
||||
booster.dispose()
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
||||
Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, treeLimit, predLeaf)
|
||||
}
|
||||
|
||||
override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
override def getFeatureScore: mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore.asScala
|
||||
}
|
||||
|
||||
override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore(featureMap).asScala
|
||||
}
|
||||
|
||||
override def saveModel(modelPath: String): Unit = {
|
||||
booster.saveModel(modelPath)
|
||||
}
|
||||
|
||||
override def finalize(): Unit = {
|
||||
super.finalize()
|
||||
dispose
|
||||
}
|
||||
}
|
||||
@@ -16,11 +16,28 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
||||
import java.io.InputStream
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
|
||||
* XGBoost Scala Training function.
|
||||
*/
|
||||
object XGBoost {
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param params Parameters.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
params: Map[String, AnyRef],
|
||||
dtrain: DMatrix,
|
||||
@@ -31,9 +48,22 @@ object XGBoost {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param data Data to be trained.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param nfold Number of folds in CV.
|
||||
* @param metrics Evaluation metrics to be watched in CV.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return evaluation history
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def crossValidation(
|
||||
params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
@@ -45,13 +75,28 @@ object XGBoost {
|
||||
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
* @param modelPath booster modelPath
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def loadModel(modelPath: String): Booster = {
|
||||
val xgboostInJava = JXGBoost.loadModel(modelPath)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
|
||||
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
/**
|
||||
* Load a new Booster model from a file opened as input stream.
|
||||
* The assumption is the input stream only contains one XGBoost Model.
|
||||
* This can be used to load existing booster models saved by other XGBoost bindings.
|
||||
*
|
||||
* @param in The input stream of the file.
|
||||
* @return The created booster
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def loadModel(in: InputStream): Booster = {
|
||||
val xgboostInJava = JXGBoost.loadModel(in)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user