re-structure Java API, add Scala API and consolidate the names of Java/Scala API
This commit is contained in:
@@ -0,0 +1,172 @@
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.dmlc.xgboost4j.{IEvaluation, IObjective, XGBoostError}
|
||||
|
||||
trait Booster {
|
||||
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
*
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setParam(key: String, value: String)
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
*
|
||||
* @param params parameters key-value map
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setParams(params: Map[String, AnyRef])
|
||||
|
||||
/**
|
||||
* Update (one iteration)
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, iter: Int)
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, obj: IObjective)
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrixs.
|
||||
*
|
||||
* @param evalMatrixs dmatrixs for evaluation
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
*
|
||||
* @param evalMatrixs evaluation matrix
|
||||
* @param evalNames evaluation names
|
||||
* @param eval custom evaluator
|
||||
* @return eval information
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
||||
with each record indicating the predicted leaf index of each sample in each tree.
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
*
|
||||
* @param modelPath model path
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def saveModel(modelPath: String)
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool Controls whether the split statistics are output.
|
||||
*/
|
||||
@throws(classOf[IOException])
|
||||
@throws(classOf[XGBoostError])
|
||||
def dumpModel(modelPath: String, withStats: Boolean)
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
*/
|
||||
@throws(classOf[IOException])
|
||||
@throws(classOf[XGBoostError])
|
||||
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore: mutable.Map[String, Integer]
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
|
||||
|
||||
def dispose
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
*
|
||||
* @param dataPath path of data file
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
def this(dataPath: String) {
|
||||
this(new JDMatrix(dataPath))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from sparse matrix
|
||||
*
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
|
||||
this(new JDMatrix(headers, indices, data, st))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(data: Array[Float], nrow: Int, ncol: Int) {
|
||||
this(new JDMatrix(data, nrow, ncol))
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
*
|
||||
* @param labels labels
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setLabel(labels: Array[Float]): Unit = {
|
||||
jDMatrix.setLabel(labels)
|
||||
}
|
||||
|
||||
/**
|
||||
* set weight of each instance
|
||||
*
|
||||
* @param weights weights
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setWeight(weights: Array[Float]): Unit = {
|
||||
jDMatrix.setWeight(weights)
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
*
|
||||
* @param baseMargin base margin
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setBaseMargin(baseMargin: Array[Float]): Unit = {
|
||||
jDMatrix.setBaseMargin(baseMargin)
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
*
|
||||
* @param baseMargin base margin
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setBaseMargin(baseMargin: Array[Array[Float]]): Unit = {
|
||||
jDMatrix.setBaseMargin(baseMargin)
|
||||
}
|
||||
|
||||
/**
|
||||
* Set group sizes of DMatrix (used for ranking)
|
||||
*
|
||||
* @param group group size as array
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setGroup(group: Array[Int]): Unit = {
|
||||
jDMatrix.setGroup(group)
|
||||
}
|
||||
|
||||
/**
|
||||
* get label values
|
||||
*
|
||||
* @return label
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getLabel: Array[Float] = {
|
||||
jDMatrix.getLabel
|
||||
}
|
||||
|
||||
/**
|
||||
* get weight of the DMatrix
|
||||
*
|
||||
* @return weights
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getWeight: Array[Float] = {
|
||||
jDMatrix.getWeight
|
||||
}
|
||||
|
||||
/**
|
||||
* get base margin of the DMatrix
|
||||
*
|
||||
* @return base margin
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getBaseMargin: Array[Float] = {
|
||||
jDMatrix.getBaseMargin
|
||||
}
|
||||
|
||||
/**
|
||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||
*
|
||||
* @param rowIndex row index
|
||||
* @return sliced new DMatrix
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def slice(rowIndex: Array[Int]): DMatrix = {
|
||||
new DMatrix(jDMatrix.slice(rowIndex))
|
||||
}
|
||||
|
||||
/**
|
||||
* get the row number of DMatrix
|
||||
*
|
||||
* @return number of rows
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def rowNum: Long = {
|
||||
jDMatrix.rowNum
|
||||
}
|
||||
|
||||
/**
|
||||
* save DMatrix to filePath
|
||||
*
|
||||
* @param filePath file path
|
||||
*/
|
||||
def saveBinary(filePath: String): Unit = {
|
||||
jDMatrix.saveBinary(filePath)
|
||||
}
|
||||
|
||||
def getHandle: Long = {
|
||||
jDMatrix.getHandle
|
||||
}
|
||||
|
||||
def delete(): Unit = {
|
||||
jDMatrix.dispose()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.dmlc.xgboost4j.{Booster => JBooster, IEvaluation, IObjective}
|
||||
|
||||
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) extends Booster {
|
||||
|
||||
override def setParam(key: String, value: String): Unit = {
|
||||
booster.setParam(key, value)
|
||||
}
|
||||
|
||||
override def update(dtrain: DMatrix, iter: Int): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
override def update(dtrain: DMatrix, obj: IObjective): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
|
||||
booster.dumpModel(modelPath, withStats)
|
||||
}
|
||||
|
||||
override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
|
||||
booster.dumpModel(modelPath, featureMap, withStats)
|
||||
}
|
||||
|
||||
override def setParams(params: Map[String, AnyRef]): Unit = {
|
||||
booster.setParams(params.asJava)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||
}
|
||||
|
||||
override def dispose: Unit = {
|
||||
booster.dispose()
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, treeLimit, predLeaf)
|
||||
}
|
||||
|
||||
override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
override def getFeatureScore: mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore.asScala
|
||||
}
|
||||
|
||||
override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore(featureMap).asScala
|
||||
}
|
||||
|
||||
override def saveModel(modelPath: String): Unit = {
|
||||
booster.saveModel(modelPath)
|
||||
}
|
||||
|
||||
override def finalize(): Unit = {
|
||||
super.finalize()
|
||||
dispose
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import org.dmlc.xgboost4j
|
||||
import org.dmlc.xgboost4j.{XGBoost => JXGBoost, IEvaluation, IObjective}
|
||||
|
||||
object XGBoost {
|
||||
|
||||
def train(params: Map[String, AnyRef], dtrain: xgboost4j.DMatrix, round: Int,
|
||||
watches: Map[String, xgboost4j.DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain, round, watches.asJava, obj, eval)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValiation(params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
nfold: Int,
|
||||
metrics: Array[String],
|
||||
obj: IObjective,
|
||||
eval: IEvaluation): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj,
|
||||
eval)
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
|
||||
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user