From 8c220f51fc11e0de1807e4db7c986ca6496a5a1c Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 2 Mar 2016 17:21:42 -0500 Subject: [PATCH] add default values for Scala API --- .../scala/ml/dmlc/xgboost4j/scala/XGBoost.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 737e4765d..977b2397c 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -21,8 +21,13 @@ import ml.dmlc.xgboost4j.{XGBoost => JXGBoost} object XGBoost { - def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int, - watches: Map[String, DMatrix], obj: ObjectiveTrait, eval: EvalTrait): Booster = { + def train( + params: Map[String, AnyRef], + dtrain: DMatrix, + round: Int, + watches: Map[String, DMatrix] = Map[String, DMatrix](), + obj: ObjectiveTrait = null, + eval: EvalTrait = null): Booster = { val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava, obj, eval) @@ -33,10 +38,10 @@ object XGBoost { params: Map[String, AnyRef], data: DMatrix, round: Int, - nfold: Int, - metrics: Array[String], - obj: ObjectiveTrait, - eval: EvalTrait): Array[String] = { + nfold: Int = 5, + metrics: Array[String] = null, + obj: ObjectiveTrait = null, + eval: EvalTrait = null): Array[String] = { JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval) }