[jvm-packages] Scala/Java interface for Fast Histogram Algorithm (#1966)
* add back train method but mark as deprecated * fix scalastyle error * first commit in scala binding for fast histo * java test * add missed scala tests * spark training * add back train method but mark as deprecated * fix scalastyle error * local change * first commit in scala binding for fast histo * local change * fix df frame test
This commit is contained in:
@@ -180,6 +180,26 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrixs.
|
||||
*
|
||||
* @param evalMatrixs dmatrixs for evaluation
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @param metricsOut output array containing the evaluation metrics for each evalMatrix
|
||||
* @return eval information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, float[] metricsOut)
|
||||
throws XGBoostError {
|
||||
String stringFormat = evalSet(evalMatrixs, evalNames, iter);
|
||||
String[] metricPairs = stringFormat.split("\t");
|
||||
for (int i = 1; i < metricPairs.length; i++) {
|
||||
metricsOut[i - 1] = Float.valueOf(metricPairs[i].split(":")[1]);
|
||||
}
|
||||
return stringFormat;
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
*
|
||||
|
||||
@@ -57,26 +57,24 @@ public class XGBoost {
|
||||
return Booster.loadModel(in);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster with given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Booster params.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param obj customized objective (set to null if not used)
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return trained booster
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
return train(dtrain, params, round, watches, null, obj, eval);
|
||||
}
|
||||
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval) throws XGBoostError {
|
||||
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
@@ -94,7 +92,7 @@ public class XGBoost {
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
if (evalMats != null && evalMats.length > 0) {
|
||||
if (evalMats.length > 0) {
|
||||
allMats = new DMatrix[evalMats.length + 1];
|
||||
allMats[0] = dtrain;
|
||||
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
|
||||
@@ -121,12 +119,20 @@ public class XGBoost {
|
||||
}
|
||||
|
||||
//evaluation
|
||||
if (evalMats != null && evalMats.length > 0) {
|
||||
if (evalMats.length > 0) {
|
||||
String evalInfo;
|
||||
if (eval != null) {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, eval);
|
||||
} else {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||
if (metrics == null) {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||
} else {
|
||||
float[] m = new float[evalMats.length];
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
|
||||
for (int i = 0; i < m.length; i++) {
|
||||
metrics[i][iter] = m[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Rabit.getRank() == 0) {
|
||||
Rabit.trackerPrint(evalInfo + '\n');
|
||||
|
||||
@@ -25,6 +25,41 @@ import scala.collection.JavaConverters._
|
||||
* XGBoost Scala Training function.
|
||||
*/
|
||||
object XGBoost {
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def train(
|
||||
dtrain: DMatrix,
|
||||
params: Map[String, Any],
|
||||
round: Int,
|
||||
watches: Map[String, DMatrix],
|
||||
metrics: Array[Array[Float]],
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
round, jWatches.asJava, metrics, obj, eval)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
@@ -45,16 +80,7 @@ object XGBoost {
|
||||
watches: Map[String, DMatrix] = Map[String, DMatrix](),
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).map{
|
||||
case (key: String, value) => (key, value.toString)
|
||||
}.toMap[String, AnyRef].asJava,
|
||||
round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new Booster(xgboostInJava)
|
||||
train(dtrain, params, round, watches, null, obj, eval)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user