adjust the API signature as well as the docs
This commit is contained in:
@@ -60,8 +60,8 @@ public class XGBoost {
|
||||
/**
|
||||
* Train a booster with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
@@ -70,11 +70,13 @@ public class XGBoost {
|
||||
* @return trained booster
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static Booster train(Map<String, Object> params,
|
||||
DMatrix dtrain, int round,
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
@@ -139,8 +141,8 @@ public class XGBoost {
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param data Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param nfold Number of folds in CV.
|
||||
* @param metrics Evaluation metrics to be watched in CV.
|
||||
@@ -150,8 +152,8 @@ public class XGBoost {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static String[] crossValidation(
|
||||
Map<String, Object> params,
|
||||
DMatrix data,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
int nfold,
|
||||
String[] metrics,
|
||||
|
||||
@@ -28,8 +28,8 @@ object XGBoost {
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param params Parameters.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
@@ -39,8 +39,8 @@ object XGBoost {
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
params: Map[String, Any],
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||
obj: ObjectiveTrait = null,
|
||||
@@ -49,10 +49,11 @@ object XGBoost {
|
||||
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
params.map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
dtrain.jDMatrix, round, jWatches.asJava,
|
||||
round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
@@ -60,8 +61,8 @@ object XGBoost {
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param data Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param nfold Number of folds in CV.
|
||||
* @param metrics Evaluation metrics to be watched in CV.
|
||||
@@ -71,17 +72,17 @@ object XGBoost {
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def crossValidation(
|
||||
params: Map[String, Any],
|
||||
data: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nfold: Int = 5,
|
||||
metrics: Array[String] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Array[String] = {
|
||||
JXGBoost.crossValidation(params.map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
JXGBoost.crossValidation(
|
||||
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
|
||||
toMap[String, AnyRef].asJava,
|
||||
round, nfold, metrics, obj, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -94,7 +94,7 @@ public class BoosterImplTest {
|
||||
int round = 5;
|
||||
|
||||
//train a boost model
|
||||
return XGBoost.train(paramMap, trainMat, round, watches, null, null);
|
||||
return XGBoost.train(trainMat, paramMap, round, watches, null, null);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -177,6 +177,6 @@ public class BoosterImplTest {
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
|
||||
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
||||
|
||||
val round = 2
|
||||
XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||
XGBoost.train(trainMat, paramMap, round, watches, null, null)
|
||||
}
|
||||
|
||||
test("basic operation of booster") {
|
||||
@@ -126,6 +126,6 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
val nfold = 5
|
||||
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
|
||||
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user