sketch of xgboost-spark
chooseBestBooster shall be in Boosters remove tracker.py rename XGBoost remove cross-validation
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@@ -22,7 +23,7 @@ import java.util.List;
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public interface IObjective {
|
||||
public interface IObjective extends Serializable {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
*
|
||||
|
||||
@@ -143,7 +143,7 @@ public class XGBoost {
|
||||
* @return evaluation history
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static String[] crossValiation(
|
||||
public static String[] crossValidation(
|
||||
Map<String, Object> params,
|
||||
DMatrix data,
|
||||
int round,
|
||||
|
||||
@@ -35,7 +35,7 @@ object XGBoost {
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValiation(
|
||||
def crossValidation(
|
||||
params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
@@ -43,7 +43,7 @@ object XGBoost {
|
||||
metrics: Array[String] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
|
||||
@@ -130,6 +130,6 @@ public class BoosterImplTest {
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
|
||||
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
val nfold = 5
|
||||
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
|
||||
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user