[jvm-packages] Scala/Java interface for Fast Histogram Algorithm (#1966)

* add back train method but mark as deprecated

* fix scalastyle error

* first commit in scala binding for fast histo

* java test

* add missed scala tests

* spark training

* add back train method but mark as deprecated

* fix scalastyle error

* local change

* first commit in scala binding for fast histo

* local change

* fix df frame test
This commit is contained in:
Nan Zhu
2017-03-04 15:37:24 -08:00
committed by GitHub
parent ac30a0aff5
commit ab13fd72bd
10 changed files with 400 additions and 37 deletions

View File

@@ -25,6 +25,41 @@ import scala.collection.JavaConverters._
* XGBoost Scala Training function.
*/
object XGBoost {
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
*/
@throws(classOf[XGBoostError])
def train(
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix],
metrics: Array[Array[Float]],
obj: ObjectiveTrait,
eval: EvalTrait): Booster = {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
round, jWatches.asJava, metrics, obj, eval)
new Booster(xgboostInJava)
}
/**
* Train a booster given parameters.
*
@@ -45,16 +80,7 @@ object XGBoost {
watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Booster = {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
round, jWatches.asJava,
obj, eval)
new Booster(xgboostInJava)
train(dtrain, params, round, watches, null, obj, eval)
}
/**