adjust the API signature as well as the docs

This commit is contained in:
CodingCat
2016-03-11 15:22:44 -05:00
parent 97e4dcde98
commit 400b1faecc
23 changed files with 58 additions and 52 deletions

View File

@@ -60,8 +60,8 @@ public class XGBoost {
/**
* Train a booster with given parameters.
*
* @param params Booster params.
* @param dtrain Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
@@ -70,11 +70,13 @@ public class XGBoost {
* @return trained booster
* @throws XGBoostError native error
*/
public static Booster train(Map<String, Object> params,
DMatrix dtrain, int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
IObjective obj,
IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
@@ -139,8 +141,8 @@ public class XGBoost {
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
@@ -150,8 +152,8 @@ public class XGBoost {
* @throws XGBoostError native error
*/
public static String[] crossValidation(
Map<String, Object> params,
DMatrix data,
Map<String, Object> params,
int round,
int nfold,
String[] metrics,

View File

@@ -28,8 +28,8 @@ object XGBoost {
/**
* Train a booster given parameters.
*
* @param params Parameters.
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
@@ -39,8 +39,8 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def train(
params: Map[String, Any],
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](),
obj: ObjectiveTrait = null,
@@ -49,10 +49,11 @@ object XGBoost {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
dtrain.jDMatrix, round, jWatches.asJava,
round, jWatches.asJava,
obj, eval)
new Booster(xgboostInJava)
}
@@ -60,8 +61,8 @@ object XGBoost {
/**
* Cross-validation with given parameters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
@@ -71,17 +72,17 @@ object XGBoost {
*/
@throws(classOf[XGBoostError])
def crossValidation(
params: Map[String, Any],
data: DMatrix,
params: Map[String, Any],
round: Int,
nfold: Int = 5,
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValidation(params.map{
case (key: String, value) => (key, value.toString)
}.toMap[String, AnyRef].asJava,
data.jDMatrix, round, nfold, metrics, obj, eval)
JXGBoost.crossValidation(
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
toMap[String, AnyRef].asJava,
round, nfold, metrics, obj, eval)
}
/**