adjust the API signature as well as the docs

This commit is contained in:
CodingCat
2016-03-11 15:22:44 -05:00
parent 97e4dcde98
commit 400b1faecc
23 changed files with 58 additions and 52 deletions

View File

@@ -60,8 +60,8 @@ public class XGBoost {
/**
* Train a booster with given parameters.
*
* @param params Booster params.
* @param dtrain Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
@@ -70,11 +70,13 @@ public class XGBoost {
* @return trained booster
* @throws XGBoostError native error
*/
public static Booster train(Map<String, Object> params,
DMatrix dtrain, int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
@@ -139,8 +141,8 @@ public class XGBoost {
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
@@ -150,8 +152,8 @@ public class XGBoost {
* @throws XGBoostError native error
*/
public static String[] crossValidation(
Map<String, Object> params,
DMatrix data,
Map<String, Object> params,
int round,
int nfold,
String[] metrics,

View File

@@ -28,8 +28,8 @@ object XGBoost {
/**
* Train a booster given parameters.
*
* @param params Parameters.
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
@@ -39,8 +39,8 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def train(
params: Map[String, Any],
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null,
@@ -49,10 +49,11 @@ object XGBoost {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
dtrain.jDMatrix, round, jWatches.asJava,
round, jWatches.asJava,
obj, eval)
new Booster(xgboostInJava)
}
@@ -60,8 +61,8 @@ object XGBoost {
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
@@ -71,17 +72,17 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def crossValidation(
params: Map[String, Any],
data: DMatrix,
params: Map[String, Any],
round: Int,
nfold: Int = 5,
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValidation(params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
data.jDMatrix, round, nfold, metrics, obj, eval)
JXGBoost.crossValidation(
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
toMap[String, AnyRef].asJava,
round, nfold, metrics, obj, eval)
}
/**

View File

@@ -94,7 +94,7 @@ public class BoosterImplTest {
int round = 5;
//train a boost model
return XGBoost.train(paramMap, trainMat, round, watches, null, null);
return XGBoost.train(trainMat, paramMap, round, watches, null, null);
}
@Test
@@ -177,6 +177,6 @@ public class BoosterImplTest {
//do 5-fold cross validation
int round = 2;
int nfold = 5;
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
}
}

View File

@@ -74,7 +74,7 @@ class ScalaBoosterImplSuite extends FunSuite {
val watches = List("train" -> trainMat, "test" -> testMat).toMap
val round = 2
XGBoost.train(paramMap, trainMat, round, watches, null, null)
XGBoost.train(trainMat, paramMap, round, watches, null, null)
}
test("basic operation of booster") {
@@ -126,6 +126,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null)
}
}