[jvm-packages] Scala/Java interface for Fast Histogram Algorithm (#1966)
* add back train method but mark as deprecated * fix scalastyle error * first commit in scala binding for fast histo * java test * add missed scala tests * spark training * add back train method but mark as deprecated * fix scalastyle error * local change * first commit in scala binding for fast histo * local change * fix df frame test
This commit is contained in:
@@ -25,6 +25,41 @@ import scala.collection.JavaConverters._
|
||||
* XGBoost Scala Training function.
|
||||
*/
|
||||
object XGBoost {
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix],
|
||||
metrics: Array[Array[Float]],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
round, jWatches.asJava, metrics, obj, eval)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
@@ -45,16 +80,7 @@ object XGBoost {
|
||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new Booster(xgboostInJava)
|
||||
train(dtrain, params, round, watches, null, obj, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user