adjust the API signature as well as the docs
This commit is contained in:
@@ -60,8 +60,8 @@ public class XGBoost {
|
||||
/**
|
||||
* Train a booster with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
@@ -70,11 +70,13 @@ public class XGBoost {
|
||||
* @return trained booster
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static Booster train(Map<String, Object> params,
|
||||
DMatrix dtrain, int round,
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
@@ -139,8 +141,8 @@ public class XGBoost {
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param data Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param nfold Number of folds in CV.
|
||||
* @param metrics Evaluation metrics to be watched in CV.
|
||||
@@ -150,8 +152,8 @@ public class XGBoost {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static String[] crossValidation(
|
||||
Map<String, Object> params,
|
||||
DMatrix data,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
int nfold,
|
||||
String[] metrics,
|
||||
|
||||
@@ -28,8 +28,8 @@ object XGBoost {
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param params Parameters.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
@@ -39,8 +39,8 @@ object XGBoost {
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
params: Map[String, Any],
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||
obj: ObjectiveTrait = null,
|
||||
@@ -49,10 +49,11 @@ object XGBoost {
|
||||
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
params.map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
dtrain.jDMatrix, round, jWatches.asJava,
|
||||
round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
@@ -60,8 +61,8 @@ object XGBoost {
|
||||
/**
|
||||
* Cross-validation with given parameters.
|
||||
*
|
||||
* @param params Booster params.
|
||||
* @param data Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param nfold Number of folds in CV.
|
||||
* @param metrics Evaluation metrics to be watched in CV.
|
||||
@@ -71,17 +72,17 @@ object XGBoost {
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def crossValidation(
|
||||
params: Map[String, Any],
|
||||
data: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
nfold: Int = 5,
|
||||
metrics: Array[String] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Array[String] = {
|
||||
JXGBoost.crossValidation(params.map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
JXGBoost.crossValidation(
|
||||
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
|
||||
toMap[String, AnyRef].asJava,
|
||||
round, nfold, metrics, obj, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user