sketch of xgboost-spark

chooseBestBooster shall be in Boosters

remove tracker.py

rename XGBoost

remove cross-validation
This commit is contained in:
CodingCat
2015-12-22 03:29:20 -06:00
parent 4568692daf
commit 1540773340
14 changed files with 187 additions and 15 deletions

View File

@@ -15,6 +15,7 @@
*/
package ml.dmlc.xgboost4j;
import java.io.Serializable;
import java.util.List;
/**
@@ -22,7 +23,7 @@ import java.util.List;
*
* @author hzx
*/
public interface IObjective {
public interface IObjective extends Serializable {
/**
* user define objective function, return gradient and second order gradient
*

View File

@@ -143,7 +143,7 @@ public class XGBoost {
* @return evaluation history
* @throws XGBoostError native error
*/
public static String[] crossValiation(
public static String[] crossValidation(
Map<String, Object> params,
DMatrix data,
int round,

View File

@@ -35,7 +35,7 @@ object XGBoost {
new ScalaBoosterImpl(xgboostInJava)
}
def crossValiation(
def crossValidation(
params: Map[String, AnyRef],
data: DMatrix,
round: Int,
@@ -43,7 +43,7 @@ object XGBoost {
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
}
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {

View File

@@ -130,6 +130,6 @@ public class BoosterImplTest {
//do 5-fold cross validation
int round = 2;
int nfold = 5;
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
}
}

View File

@@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
}
}