[jvm-packages] Implemented early stopping (#2710)
* Allowed subsampling test from the training data frame/RDD The implementation requires storing 1 - trainTestRatio points in memory to make the sampling work. An alternative approach would be to construct the full DMatrix and then slice it deterministically into train/test. The peak memory consumption of such scenario, however, is twice the dataset size. * Removed duplication from 'XGBoost.train' Scala callers can (and should) use names to supply a subset of parameters. Method overloading is not required. * Reuse XGBoost seed parameter to stabilize train/test splitting * Added early stopping support to non-distributed XGBoost Closes #1544 * Added early-stopping to distributed XGBoost * Moved construction of 'watches' into a separate method This commit also fixes the handling of 'baseMargin' which previously was not added to the validation matrix. * Addressed review comments
This commit is contained in:
@@ -36,6 +36,9 @@ object XGBoost {
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* increases in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
@@ -45,44 +48,20 @@ object XGBoost {
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix],
|
||||
metrics: Array[Array[Float]],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
watches: Map[String, DMatrix] = Map(),
|
||||
metrics: Array[Array[Float]] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
round, jWatches.asJava, metrics, obj, eval)
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Booster = {
|
||||
train(dtrain, params, round, watches, null, obj, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user