example of DistTrainWithSpark and trigger job with foreachPartition
This commit is contained in:
parent
f768edfede
commit
808e30f9fc
@ -25,7 +25,7 @@
|
|||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j</artifactId>
|
<artifactId>xgboost4j-spark</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|||||||
@ -0,0 +1,74 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark.demo
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
|
import scala.collection.mutable.ListBuffer
|
||||||
|
import scala.io.Source
|
||||||
|
|
||||||
|
import org.apache.spark.SparkContext
|
||||||
|
import org.apache.spark.mllib.linalg.DenseVector
|
||||||
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
|
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Demo entry point: reads a training and a test file from the local file system,
 * trains a booster through `XGBoost.train` on a Spark RDD and runs a prediction
 * on the test data.
 *
 * Input lines are expected as `label id:value id:value ...`, separated by single
 * spaces, with every feature id in `[0, FeatureDimension)`.
 */
object DistTrainWithSpark {

  /** Length of the dense feature array built for every sample. */
  private val FeatureDimension = 129

  /**
   * Reads every non-blank line of `filePath` and parses it into a labeled point.
   *
   * The lines are materialised into a `List` before the file is closed, so the
   * returned value does not depend on the underlying reader.
   */
  private def readFile(filePath: String): List[LabeledPoint] = {
    val file = Source.fromFile(new File(filePath))
    try {
      file.getLines()
        .filter(_.trim.nonEmpty) // a trailing blank line would otherwise fail on "".toInt
        .map(fromSVMStringToLabeledPoint)
        .toList
    } finally {
      file.close()
    }
  }

  /**
   * Parses one `label id:value ...` line into a dense labeled point.
   *
   * Feature ids that are not listed keep the default value 0.0.
   */
  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
    // Empty tokens appear when two spaces follow each other; they carry no data.
    val labelAndFeatures = line.split(" ").filter(_.nonEmpty)
    val label = labelAndFeatures(0).toInt
    val denseFeature = new Array[Double](FeatureDimension)
    for (feature <- labelAndFeatures.tail) {
      val idAndValue = feature.split(":")
      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
    }
    LabeledPoint(label, new DenseVector(denseFeature))
  }

  /**
   * Expects exactly four arguments: number of training partitions, number of
   * boosting rounds, training file path and test file path.
   */
  def main(args: Array[String]): Unit = {
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
    if (args.length != 4) {
      println(
        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
      sys.exit(1)
    }
    // Parse the numeric arguments first so a malformed value fails before Spark starts.
    val numPartitions = args(0).toInt
    val numRounds = args(1).toInt
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
    val sc = new SparkContext()
    try {
      val trainingLabeledPoints = readFile(inputTrainPath)
      val trainRDD = sc.parallelize(trainingLabeledPoints, numPartitions)
      val testLabeledPoints = readFile(inputTestPath).iterator
      val testMatrix = new DMatrix(testLabeledPoints, null)
      val booster = XGBoost.train(trainRDD,
        List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
          "objective" -> "binary:logistic").toMap, numRounds, null, null)
      booster.map(boosterInstance => boosterInstance.predict(testMatrix))
    } finally {
      // Release the context even when reading, training or prediction fails.
      sc.stop()
    }
  }
}
|
||||||
@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint
|
import ml.dmlc.xgboost4j.LabeledPoint
|
||||||
|
|
||||||
private[spark] object DataUtils extends Serializable {
|
object DataUtils extends Serializable {
|
||||||
|
|
||||||
|
implicit def fromSparkToXGBoostLabeledPointsAsJava(
|
||||||
|
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
|
||||||
|
fromSparkToXGBoostLabeledPoints(sps).asJava
|
||||||
|
}
|
||||||
|
|
||||||
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
||||||
java.util.Iterator[LabeledPoint] = {
|
Iterator[LabeledPoint] = {
|
||||||
(for (p <- sps) yield {
|
for (p <- sps) yield {
|
||||||
p.features match {
|
p.features match {
|
||||||
case denseFeature: DenseVector =>
|
case denseFeature: DenseVector =>
|
||||||
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
|
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
|
||||||
@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
|
|||||||
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
|
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
|
||||||
sparseFeature.values.map(_.toFloat))
|
sparseFeature.values.map(_.toFloat))
|
||||||
}
|
}
|
||||||
}).asJava
|
}
|
||||||
}
|
|
||||||
|
|
||||||
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
|
|
||||||
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def fetchUpdateFromVector(feature: Vector) = feature match {
|
|
||||||
case denseFeature: DenseVector =>
|
|
||||||
fetchUpdateFromSparseVector(denseFeature.toSparse)
|
|
||||||
case sparseFeature: SparseVector =>
|
|
||||||
fetchUpdateFromSparseVector(sparseFeature)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -61,7 +61,8 @@ object XGBoost extends Serializable {
|
|||||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||||
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
|
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
|
||||||
// force the job
|
// force the job
|
||||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
boosters.foreachPartition(_ => ())
|
||||||
|
println("=====finished training=====")
|
||||||
val booster = boosters.first()
|
val booster = boosters.first()
|
||||||
val returnVal = tracker.waitFor()
|
val returnVal = tracker.waitFor()
|
||||||
logger.info(s"Rabit returns with exit code $returnVal")
|
logger.info(s"Rabit returns with exit code $returnVal")
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Labeled data point for training examples.
|
* Labeled data point for training examples.
|
||||||
* Represent a sparse training instance.
|
* Represent a sparse training instance.
|
||||||
*/
|
*/
|
||||||
public class LabeledPoint {
|
public class LabeledPoint implements Serializable {
|
||||||
/** Label of the point */
|
/** Label of the point */
|
||||||
public float label;
|
public float label;
|
||||||
/** Weight of this data point */
|
/** Weight of this data point */
|
||||||
|
|||||||
@ -1,41 +1,140 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.*;
|
||||||
import java.io.Serializable;
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public interface Booster extends Serializable {
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Booster for xgboost, similar to the python wrapper xgboost.py
|
||||||
|
* but custom obj function and eval function not supported at present.
|
||||||
|
*
|
||||||
|
* @author hzx
|
||||||
|
*/
|
||||||
|
public class Booster implements Serializable {
|
||||||
|
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||||
|
|
||||||
|
long handle = 0;
|
||||||
|
|
||||||
|
//load native library
|
||||||
|
static {
  try {
    NativeLibLoader.initXgBoost();
  } catch (IOException ex) {
    // The failure is logged and not rethrown, as before. Passing the exception
    // as the cause keeps its stack trace in the log instead of only toString().
    logger.error("load native library failed.", ex);
  }
}
|
||||||
|
|
||||||
|
/**
 * Creates a booster for the given matrices.
 *
 * The native handle is created first, the "seed" parameter is then set to "0",
 * and the caller's parameters are applied last.
 *
 * @param params   parameters, may be null
 * @param dMatrixs DMatrix array handed to the native booster, may be null
 * @throws XGBoostError native error
 */
Booster(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
  this.init(dMatrixs);
  this.setParam("seed", "0");
  this.setParams(params);
}
|
||||||
|
|
||||||
|
/**
 * Creates a booster and loads its model from modelPath.
 *
 * @param params    parameters applied after the model is loaded, may be null
 * @param modelPath booster modelPath (model generated by booster.saveModel)
 * @throws NullPointerException if modelPath is null
 * @throws XGBoostError         native error
 */
Booster(Map<String, Object> params, String modelPath) throws XGBoostError {
  // Validate before init() so no native handle is created for a rejected argument.
  if (modelPath == null) {
    throw new NullPointerException("modelPath : null");
  }
  init(null);
  loadModel(modelPath);
  setParam("seed", "0");
  setParams(params);
}
|
||||||
|
|
||||||
|
|
||||||
|
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||||
|
long[] handles = null;
|
||||||
|
if (dMatrixs != null) {
|
||||||
|
handles = dmatrixsToHandles(dMatrixs);
|
||||||
|
}
|
||||||
|
long[] out = new long[1];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
||||||
|
|
||||||
|
handle = out[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameter
|
* set parameter
|
||||||
*
|
*
|
||||||
* @param key param name
|
* @param key param name
|
||||||
* @param value param value
|
* @param value param value
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
void setParam(String key, String value) throws XGBoostError;
|
public final void setParam(String key, String value) throws XGBoostError {
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameters
|
* set parameters
|
||||||
*
|
*
|
||||||
* @param params parameters key-value map
|
* @param params parameters key-value map
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
void setParams(Map<String, Object> params) throws XGBoostError;
|
public void setParams(Map<String, Object> params) throws XGBoostError {
|
||||||
|
if (params != null) {
|
||||||
|
for (Map.Entry<String, Object> entry : params.entrySet()) {
|
||||||
|
setParam(entry.getKey(), entry.getValue().toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update (one iteration)
|
* Update (one iteration)
|
||||||
*
|
*
|
||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param iter current iteration number
|
* @param iter current iteration number
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
void update(DMatrix dtrain, int iter) throws XGBoostError;
|
public void update(DMatrix dtrain, int iter) throws XGBoostError {
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* update with customize obj func
|
* update with customize obj func
|
||||||
*
|
*
|
||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param obj customized objective class
|
* @param obj customized objective class
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
void update(DMatrix dtrain, IObjective obj) throws XGBoostError;
|
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||||
|
float[][] predicts = predict(dtrain, true);
|
||||||
|
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||||
|
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* update with give grad and hess
|
* update with given grad and hess
|
||||||
@ -43,8 +142,16 @@ public interface Booster extends Serializable {
|
|||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param grad first order of gradient
|
* @param grad first order of gradient
|
||||||
* @param hess seconde order of gradient
|
* @param hess second order of gradient
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError;
|
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||||
|
if (grad.length != hess.length) {
|
||||||
|
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||||
|
hess.length));
|
||||||
|
}
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||||
|
hess));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* evaluate with given dmatrixs.
|
* evaluate with given dmatrixs.
|
||||||
@ -53,8 +160,15 @@ public interface Booster extends Serializable {
|
|||||||
* @param evalNames name for eval dmatrixs, used for check results
|
* @param evalNames name for eval dmatrixs, used for check results
|
||||||
* @param iter current eval iteration
|
* @param iter current eval iteration
|
||||||
* @return eval information
|
* @return eval information
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError;
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||||
|
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||||
|
String[] evalInfo = new String[1];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||||
|
evalInfo));
|
||||||
|
return evalInfo[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* evaluate with given customized Evaluation class
|
* evaluate with given customized Evaluation class
|
||||||
@ -63,17 +177,64 @@ public interface Booster extends Serializable {
|
|||||||
* @param evalNames evaluation names
|
* @param evalNames evaluation names
|
||||||
* @param eval custom evaluator
|
* @param eval custom evaluator
|
||||||
* @return eval information
|
* @return eval information
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError;
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
||||||
|
throws XGBoostError {
|
||||||
|
String evalInfo = "";
|
||||||
|
for (int i = 0; i < evalNames.length; i++) {
|
||||||
|
String evalName = evalNames[i];
|
||||||
|
DMatrix evalMat = evalMatrixs[i];
|
||||||
|
float evalResult = eval.eval(predict(evalMat), evalMat);
|
||||||
|
String evalMetric = eval.getMetric();
|
||||||
|
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
||||||
|
}
|
||||||
|
return evalInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* base function for Predict
|
||||||
|
*
|
||||||
|
* @param data data
|
||||||
|
* @param outPutMargin output margin
|
||||||
|
* @param treeLimit limit number of trees
|
||||||
|
* @param predLeaf prediction minimum to keep leafs
|
||||||
|
* @return predict results
|
||||||
|
*/
|
||||||
|
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
|
||||||
|
boolean predLeaf) throws XGBoostError {
|
||||||
|
int optionMask = 0;
|
||||||
|
if (outPutMargin) {
|
||||||
|
optionMask = 1;
|
||||||
|
}
|
||||||
|
if (predLeaf) {
|
||||||
|
optionMask = 2;
|
||||||
|
}
|
||||||
|
float[][] rawPredicts = new float[1][];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||||
|
treeLimit, rawPredicts));
|
||||||
|
int row = (int) data.rowNum();
|
||||||
|
int col = rawPredicts[0].length / row;
|
||||||
|
float[][] predicts = new float[row][col];
|
||||||
|
int r, c;
|
||||||
|
for (int i = 0; i < rawPredicts[0].length; i++) {
|
||||||
|
r = i / col;
|
||||||
|
c = i % col;
|
||||||
|
predicts[r][c] = rawPredicts[0][i];
|
||||||
|
}
|
||||||
|
return predicts;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Predict with data
|
||||||
*
|
*
|
||||||
* @param data dmatrix storing the input
|
* @param data dmatrix storing the input
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
float[][] predict(DMatrix data) throws XGBoostError;
|
public float[][] predict(DMatrix data) throws XGBoostError {
|
||||||
|
return pred(data, false, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Predict with data
|
||||||
@ -81,9 +242,11 @@ public interface Booster extends Serializable {
|
|||||||
* @param data dmatrix storing the input
|
* @param data dmatrix storing the input
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError;
|
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
|
||||||
|
return pred(data, outPutMargin, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Predict with data
|
||||||
@ -92,31 +255,189 @@ public interface Booster extends Serializable {
|
|||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError;
|
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
|
||||||
|
return pred(data, outPutMargin, treeLimit, false);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict with data
|
* Predict with data
|
||||||
* @param data dmatrix storing the input
|
*
|
||||||
|
* @param data dmatrix storing the input
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
* nsample = data.numRow with each record indicating the predicted leaf index
|
||||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
* of each sample in each tree.
|
||||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||||
|
* in both tree 1 and tree 0.
|
||||||
* @return predict result
|
* @return predict result
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
|
||||||
|
return pred(data, false, treeLimit, predLeaf);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* save model to modelPath, the model path support depends on the path support
|
* save model to modelPath
|
||||||
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
|
*
|
||||||
* compiled with HDFS support.
|
|
||||||
* See also toByteArray
|
|
||||||
* @param modelPath model path
|
* @param modelPath model path
|
||||||
*/
|
*/
|
||||||
void saveModel(String modelPath) throws XGBoostError;
|
public void saveModel(String modelPath) throws XGBoostError{
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loadModel(String modelPath) {
|
||||||
|
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the dump of the model as a string array
|
||||||
|
*
|
||||||
|
* @param withStats Controls whether the split statistics are output.
|
||||||
|
* @return dumped model information
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
||||||
|
int statsFlag = 0;
|
||||||
|
if (withStats) {
|
||||||
|
statsFlag = 1;
|
||||||
|
}
|
||||||
|
String[][] modelInfos = new String[1][];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||||
|
return modelInfos[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the dump of the model as a string array
|
||||||
|
*
|
||||||
|
* @param featureMap featureMap file
|
||||||
|
* @param withStats Controls whether the split statistics are output.
|
||||||
|
* @return dumped model information
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
|
||||||
|
int statsFlag = 0;
|
||||||
|
if (withStats) {
|
||||||
|
statsFlag = 1;
|
||||||
|
}
|
||||||
|
String[][] modelInfos = new String[1][];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||||
|
modelInfos));
|
||||||
|
return modelInfos[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dump model into a text file.
|
||||||
|
*
|
||||||
|
* @param modelPath file to save dumped model info
|
||||||
|
* @param withStats bool
|
||||||
|
* Controls whether the split statistics are output.
|
||||||
|
* @throws FileNotFoundException file not found
|
||||||
|
* @throws UnsupportedEncodingException unsupported feature
|
||||||
|
* @throws IOException error with model writing
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
|
||||||
|
File tf = new File(modelPath);
|
||||||
|
FileOutputStream out = new FileOutputStream(tf);
|
||||||
|
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||||
|
String[] modelInfos = getDumpInfo(withStats);
|
||||||
|
|
||||||
|
for (int i = 0; i < modelInfos.length; i++) {
|
||||||
|
writer.write("booster [" + i + "]:\n");
|
||||||
|
writer.write(modelInfos[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.close();
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dump model into a text file.
|
||||||
|
*
|
||||||
|
* @param modelPath file to save dumped model info
|
||||||
|
* @param featureMap featureMap file
|
||||||
|
* @param withStats bool
|
||||||
|
* Controls whether the split statistics are output.
|
||||||
|
* @throws FileNotFoundException exception
|
||||||
|
* @throws UnsupportedEncodingException exception
|
||||||
|
* @throws IOException exception
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
|
||||||
|
IOException, XGBoostError {
|
||||||
|
File tf = new File(modelPath);
|
||||||
|
FileOutputStream out = new FileOutputStream(tf);
|
||||||
|
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||||
|
String[] modelInfos = getDumpInfo(featureMap, withStats);
|
||||||
|
|
||||||
|
for (int i = 0; i < modelInfos.length; i++) {
|
||||||
|
writer.write("booster [" + i + "]:\n");
|
||||||
|
writer.write(modelInfos[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.close();
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get importance of each feature
|
||||||
|
*
|
||||||
|
* @return featureMap key: feature index, value: feature importance score
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public Map<String, Integer> getFeatureScore() throws XGBoostError {
|
||||||
|
String[] modelInfos = getDumpInfo(false);
|
||||||
|
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||||
|
for (String tree : modelInfos) {
|
||||||
|
for (String node : tree.split("\n")) {
|
||||||
|
String[] array = node.split("\\[");
|
||||||
|
if (array.length == 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String fid = array[1].split("\\]")[0];
|
||||||
|
fid = fid.split("<")[0];
|
||||||
|
if (featureScore.containsKey(fid)) {
|
||||||
|
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||||
|
} else {
|
||||||
|
featureScore.put(fid, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return featureScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get importance of each feature
|
||||||
|
*
|
||||||
|
* @param featureMap file to save dumped model info
|
||||||
|
* @return featureMap key: feature index, value: feature importance score
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||||
|
String[] modelInfos = getDumpInfo(featureMap, false);
|
||||||
|
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||||
|
for (String tree : modelInfos) {
|
||||||
|
for (String node : tree.split("\n")) {
|
||||||
|
String[] array = node.split("\\[");
|
||||||
|
if (array.length == 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String fid = array[1].split("\\]")[0];
|
||||||
|
fid = fid.split("<")[0];
|
||||||
|
if (featureScore.containsKey(fid)) {
|
||||||
|
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||||
|
} else {
|
||||||
|
featureScore.put(fid, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return featureScore;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save the model as byte array representation.
|
* Save the model as byte array representation.
|
||||||
@ -127,41 +448,77 @@ public interface Booster extends Serializable {
|
|||||||
* @return the saved byte array.
|
* @return the saved byte array.
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
byte[] toByteArray() throws XGBoostError;
|
public byte[] toByteArray() throws XGBoostError {
|
||||||
|
byte[][] bytes = new byte[1][];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||||
|
return bytes[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dump model into a text file.
|
* Load the booster model from thread-local rabit checkpoint.
|
||||||
*
|
* This is only used in distributed training.
|
||||||
* @param modelPath file to save dumped model info
|
* @return the stored version number of the checkpoint.
|
||||||
* @param withStats bool Controls whether the split statistics are output.
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError;
|
int loadRabitCheckpoint() throws XGBoostError {
|
||||||
|
int[] out = new int[1];
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||||
|
return out[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dump model into a text file.
|
* Save the booster model into thread-local rabit checkpoint.
|
||||||
*
|
* This is only used in distributed training.
|
||||||
* @param modelPath file to save dumped model info
|
* @throws XGBoostError
|
||||||
* @param featureMap featureMap file
|
|
||||||
* @param withStats bool
|
|
||||||
* Controls whether the split statistics are output.
|
|
||||||
*/
|
*/
|
||||||
void dumpModel(String modelPath, String featureMap, boolean withStats)
|
void saveRabitCheckpoint() throws XGBoostError {
|
||||||
throws IOException, XGBoostError;
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get importance of each feature
|
* transfer DMatrix array to handle array (used for native functions)
|
||||||
*
|
*
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
* @param dmatrixs
|
||||||
|
* @return handle array for input dmatrixs
|
||||||
*/
|
*/
|
||||||
Map<String, Integer> getFeatureScore() throws XGBoostError ;
|
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
||||||
|
long[] handles = new long[dmatrixs.length];
|
||||||
|
for (int i = 0; i < dmatrixs.length; i++) {
|
||||||
|
handles[i] = dmatrixs[i].getHandle();
|
||||||
|
}
|
||||||
|
return handles;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
// making Booster serializable
|
||||||
* get importance of each feature
|
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||||
*
|
try {
|
||||||
* @param featureMap file to save dumped model info
|
out.writeObject(this.toByteArray());
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
} catch (XGBoostError ex) {
|
||||||
*/
|
throw new IOException(ex.toString());
|
||||||
Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void dispose();
|
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
|
||||||
|
try {
|
||||||
|
this.init(null);
|
||||||
|
byte[] bytes = (byte[])in.readObject();
|
||||||
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
throw new IOException(ex.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void finalize() throws Throwable {
|
||||||
|
super.finalize();
|
||||||
|
dispose();
|
||||||
|
}
|
||||||
|
|
||||||
|
public synchronized void dispose() {
|
||||||
|
if (handle != 0L) {
|
||||||
|
XGBoostJNI.XGBoosterFree(handle);
|
||||||
|
handle = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||||
@ -56,7 +57,7 @@ class DataBatch {
|
|||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
static class BatchIterator implements Iterator<DataBatch> {
|
static class BatchIterator implements Iterator<DataBatch>, Serializable {
|
||||||
private Iterator<LabeledPoint> base;
|
private Iterator<LabeledPoint> base;
|
||||||
private int batchSize;
|
private int batchSize;
|
||||||
|
|
||||||
|
|||||||
@ -1,525 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
package ml.dmlc.xgboost4j.java;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
|
||||||
import org.apache.commons.logging.LogFactory;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Booster for xgboost, similar to the python wrapper xgboost.py
|
|
||||||
* but custom obj function and eval function not supported at present.
|
|
||||||
*
|
|
||||||
* @author hzx
|
|
||||||
*/
|
|
||||||
class JavaBoosterImpl implements Booster {
|
|
||||||
private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class);
|
|
||||||
|
|
||||||
long handle = 0;
|
|
||||||
|
|
||||||
//load native library
|
|
||||||
static {
|
|
||||||
try {
|
|
||||||
NativeLibLoader.initXgBoost();
|
|
||||||
} catch (IOException ex) {
|
|
||||||
logger.error("load native library failed.");
|
|
||||||
logger.error(ex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* init Booster from dMatrixs
|
|
||||||
*
|
|
||||||
* @param params parameters
|
|
||||||
* @param dMatrixs DMatrix array
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
JavaBoosterImpl(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
|
|
||||||
init(dMatrixs);
|
|
||||||
setParam("seed", "0");
|
|
||||||
setParams(params);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* load model from modelPath
|
|
||||||
*
|
|
||||||
* @param params parameters
|
|
||||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
JavaBoosterImpl(Map<String, Object> params, String modelPath) throws XGBoostError {
|
|
||||||
init(null);
|
|
||||||
if (modelPath == null) {
|
|
||||||
throw new NullPointerException("modelPath : null");
|
|
||||||
}
|
|
||||||
loadModel(modelPath);
|
|
||||||
setParam("seed", "0");
|
|
||||||
setParams(params);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
|
||||||
long[] handles = null;
|
|
||||||
if (dMatrixs != null) {
|
|
||||||
handles = dmatrixsToHandles(dMatrixs);
|
|
||||||
}
|
|
||||||
long[] out = new long[1];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
|
||||||
|
|
||||||
handle = out[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* set parameter
|
|
||||||
*
|
|
||||||
* @param key param name
|
|
||||||
* @param value param value
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public final void setParam(String key, String value) throws XGBoostError {
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* set parameters
|
|
||||||
*
|
|
||||||
* @param params parameters key-value map
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void setParams(Map<String, Object> params) throws XGBoostError {
|
|
||||||
if (params != null) {
|
|
||||||
for (Map.Entry<String, Object> entry : params.entrySet()) {
|
|
||||||
setParam(entry.getKey(), entry.getValue().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Update (one iteration)
|
|
||||||
*
|
|
||||||
* @param dtrain training data
|
|
||||||
* @param iter current iteration number
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void update(DMatrix dtrain, int iter) throws XGBoostError {
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* update with customize obj func
|
|
||||||
*
|
|
||||||
* @param dtrain training data
|
|
||||||
* @param obj customized objective class
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
|
||||||
float[][] predicts = predict(dtrain, true);
|
|
||||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
|
||||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* update with give grad and hess
|
|
||||||
*
|
|
||||||
* @param dtrain training data
|
|
||||||
* @param grad first order of gradient
|
|
||||||
* @param hess seconde order of gradient
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
|
||||||
if (grad.length != hess.length) {
|
|
||||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
|
||||||
hess.length));
|
|
||||||
}
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
|
||||||
hess));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* evaluate with given dmatrixs.
|
|
||||||
*
|
|
||||||
* @param evalMatrixs dmatrixs for evaluation
|
|
||||||
* @param evalNames name for eval dmatrixs, used for check results
|
|
||||||
* @param iter current eval iteration
|
|
||||||
* @return eval information
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
|
||||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
|
||||||
String[] evalInfo = new String[1];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
|
||||||
evalInfo));
|
|
||||||
return evalInfo[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* evaluate with given customized Evaluation class
|
|
||||||
*
|
|
||||||
* @param evalMatrixs evaluation matrix
|
|
||||||
* @param evalNames evaluation names
|
|
||||||
* @param eval custom evaluator
|
|
||||||
* @return eval information
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
|
||||||
throws XGBoostError {
|
|
||||||
String evalInfo = "";
|
|
||||||
for (int i = 0; i < evalNames.length; i++) {
|
|
||||||
String evalName = evalNames[i];
|
|
||||||
DMatrix evalMat = evalMatrixs[i];
|
|
||||||
float evalResult = eval.eval(predict(evalMat), evalMat);
|
|
||||||
String evalMetric = eval.getMetric();
|
|
||||||
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
|
||||||
}
|
|
||||||
return evalInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* base function for Predict
|
|
||||||
*
|
|
||||||
* @param data data
|
|
||||||
* @param outPutMargin output margin
|
|
||||||
* @param treeLimit limit number of trees
|
|
||||||
* @param predLeaf prediction minimum to keep leafs
|
|
||||||
* @return predict results
|
|
||||||
*/
|
|
||||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
|
|
||||||
boolean predLeaf) throws XGBoostError {
|
|
||||||
int optionMask = 0;
|
|
||||||
if (outPutMargin) {
|
|
||||||
optionMask = 1;
|
|
||||||
}
|
|
||||||
if (predLeaf) {
|
|
||||||
optionMask = 2;
|
|
||||||
}
|
|
||||||
float[][] rawPredicts = new float[1][];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
|
||||||
treeLimit, rawPredicts));
|
|
||||||
int row = (int) data.rowNum();
|
|
||||||
int col = rawPredicts[0].length / row;
|
|
||||||
float[][] predicts = new float[row][col];
|
|
||||||
int r, c;
|
|
||||||
for (int i = 0; i < rawPredicts[0].length; i++) {
|
|
||||||
r = i / col;
|
|
||||||
c = i % col;
|
|
||||||
predicts[r][c] = rawPredicts[0][i];
|
|
||||||
}
|
|
||||||
return predicts;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @return predict result
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public float[][] predict(DMatrix data) throws XGBoostError {
|
|
||||||
return pred(data, false, 0, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
|
||||||
* @return predict result
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
|
|
||||||
return pred(data, outPutMargin, 0, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
|
||||||
* @return predict result
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
|
|
||||||
return pred(data, outPutMargin, treeLimit, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
|
||||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
|
||||||
* nsample = data.numRow with each record indicating the predicted leaf index
|
|
||||||
* of each sample in each tree.
|
|
||||||
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
|
||||||
* in both tree 1 and tree 0.
|
|
||||||
* @return predict result
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
|
|
||||||
return pred(data, false, treeLimit, predLeaf);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* save model to modelPath
|
|
||||||
*
|
|
||||||
* @param modelPath model path
|
|
||||||
*/
|
|
||||||
public void saveModel(String modelPath) throws XGBoostError{
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void loadModel(String modelPath) {
|
|
||||||
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get the dump of the model as a string array
|
|
||||||
*
|
|
||||||
* @param withStats Controls whether the split statistics are output.
|
|
||||||
* @return dumped model information
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
|
||||||
int statsFlag = 0;
|
|
||||||
if (withStats) {
|
|
||||||
statsFlag = 1;
|
|
||||||
}
|
|
||||||
String[][] modelInfos = new String[1][];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
|
||||||
return modelInfos[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get the dump of the model as a string array
|
|
||||||
*
|
|
||||||
* @param featureMap featureMap file
|
|
||||||
* @param withStats Controls whether the split statistics are output.
|
|
||||||
* @return dumped model information
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
|
|
||||||
int statsFlag = 0;
|
|
||||||
if (withStats) {
|
|
||||||
statsFlag = 1;
|
|
||||||
}
|
|
||||||
String[][] modelInfos = new String[1][];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
|
||||||
modelInfos));
|
|
||||||
return modelInfos[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dump model into a text file.
|
|
||||||
*
|
|
||||||
* @param modelPath file to save dumped model info
|
|
||||||
* @param withStats bool
|
|
||||||
* Controls whether the split statistics are output.
|
|
||||||
* @throws FileNotFoundException file not found
|
|
||||||
* @throws UnsupportedEncodingException unsupported feature
|
|
||||||
* @throws IOException error with model writing
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
|
|
||||||
File tf = new File(modelPath);
|
|
||||||
FileOutputStream out = new FileOutputStream(tf);
|
|
||||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
|
||||||
String[] modelInfos = getDumpInfo(withStats);
|
|
||||||
|
|
||||||
for (int i = 0; i < modelInfos.length; i++) {
|
|
||||||
writer.write("booster [" + i + "]:\n");
|
|
||||||
writer.write(modelInfos[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.close();
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dump model into a text file.
|
|
||||||
*
|
|
||||||
* @param modelPath file to save dumped model info
|
|
||||||
* @param featureMap featureMap file
|
|
||||||
* @param withStats bool
|
|
||||||
* Controls whether the split statistics are output.
|
|
||||||
* @throws FileNotFoundException exception
|
|
||||||
* @throws UnsupportedEncodingException exception
|
|
||||||
* @throws IOException exception
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
|
|
||||||
IOException, XGBoostError {
|
|
||||||
File tf = new File(modelPath);
|
|
||||||
FileOutputStream out = new FileOutputStream(tf);
|
|
||||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
|
||||||
String[] modelInfos = getDumpInfo(featureMap, withStats);
|
|
||||||
|
|
||||||
for (int i = 0; i < modelInfos.length; i++) {
|
|
||||||
writer.write("booster [" + i + "]:\n");
|
|
||||||
writer.write(modelInfos[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.close();
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get importance of each feature
|
|
||||||
*
|
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public Map<String, Integer> getFeatureScore() throws XGBoostError {
|
|
||||||
String[] modelInfos = getDumpInfo(false);
|
|
||||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
|
||||||
for (String tree : modelInfos) {
|
|
||||||
for (String node : tree.split("\n")) {
|
|
||||||
String[] array = node.split("\\[");
|
|
||||||
if (array.length == 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String fid = array[1].split("\\]")[0];
|
|
||||||
fid = fid.split("<")[0];
|
|
||||||
if (featureScore.containsKey(fid)) {
|
|
||||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
|
||||||
} else {
|
|
||||||
featureScore.put(fid, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return featureScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get importance of each feature
|
|
||||||
*
|
|
||||||
* @param featureMap file to save dumped model info
|
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
|
||||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
|
||||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
|
||||||
for (String tree : modelInfos) {
|
|
||||||
for (String node : tree.split("\n")) {
|
|
||||||
String[] array = node.split("\\[");
|
|
||||||
if (array.length == 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String fid = array[1].split("\\]")[0];
|
|
||||||
fid = fid.split("<")[0];
|
|
||||||
if (featureScore.containsKey(fid)) {
|
|
||||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
|
||||||
} else {
|
|
||||||
featureScore.put(fid, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return featureScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save the model as byte array representation.
|
|
||||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
|
||||||
*
|
|
||||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
|
||||||
*
|
|
||||||
* @return the saved byte array.
|
|
||||||
* @throws XGBoostError
|
|
||||||
*/
|
|
||||||
public byte[] toByteArray() throws XGBoostError {
|
|
||||||
byte[][] bytes = new byte[1][];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
|
||||||
return bytes[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load the booster model from thread-local rabit checkpoint.
|
|
||||||
* This is only used in distributed training.
|
|
||||||
* @return the stored version number of the checkpoint.
|
|
||||||
* @throws XGBoostError
|
|
||||||
*/
|
|
||||||
int loadRabitCheckpoint() throws XGBoostError {
|
|
||||||
int[] out = new int[1];
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
|
||||||
return out[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save the booster model into thread-local rabit checkpoint.
|
|
||||||
* This is only used in distributed training.
|
|
||||||
* @throws XGBoostError
|
|
||||||
*/
|
|
||||||
void saveRabitCheckpoint() throws XGBoostError {
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* transfer DMatrix array to handle array (used for native functions)
|
|
||||||
*
|
|
||||||
* @param dmatrixs
|
|
||||||
* @return handle array for input dmatrixs
|
|
||||||
*/
|
|
||||||
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
|
||||||
long[] handles = new long[dmatrixs.length];
|
|
||||||
for (int i = 0; i < dmatrixs.length; i++) {
|
|
||||||
handles[i] = dmatrixs[i].getHandle();
|
|
||||||
}
|
|
||||||
return handles;
|
|
||||||
}
|
|
||||||
|
|
||||||
// making Booster serializable
|
|
||||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
|
||||||
try {
|
|
||||||
out.writeObject(this.toByteArray());
|
|
||||||
} catch (XGBoostError ex) {
|
|
||||||
throw new IOException(ex.toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void readObject(java.io.ObjectInputStream in)
|
|
||||||
throws IOException, ClassNotFoundException {
|
|
||||||
try {
|
|
||||||
this.init(null);
|
|
||||||
byte[] bytes = (byte[])in.readObject();
|
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
|
||||||
} catch (XGBoostError ex) {
|
|
||||||
throw new IOException(ex.toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void finalize() throws Throwable {
|
|
||||||
super.finalize();
|
|
||||||
dispose();
|
|
||||||
}
|
|
||||||
|
|
||||||
public synchronized void dispose() {
|
|
||||||
if (handle != 0L) {
|
|
||||||
XGBoostJNI.XGBoosterFree(handle);
|
|
||||||
handle = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,6 +1,7 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
@ -9,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
|
|||||||
/**
|
/**
|
||||||
* Rabit global class for synchronization.
|
* Rabit global class for synchronization.
|
||||||
*/
|
*/
|
||||||
public class Rabit {
|
public class Rabit implements Serializable {
|
||||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
|
|||||||
@ -71,7 +71,7 @@ public class XGBoost {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//initialize booster
|
//initialize booster
|
||||||
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
|
Booster booster = new Booster(params, allMats);
|
||||||
|
|
||||||
int version = booster.loadRabitCheckpoint();
|
int version = booster.loadRabitCheckpoint();
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ public class XGBoost {
|
|||||||
public static Booster initBoostingModel(
|
public static Booster initBoostingModel(
|
||||||
Map<String, Object> params,
|
Map<String, Object> params,
|
||||||
DMatrix[] dMatrixs) throws XGBoostError {
|
DMatrix[] dMatrixs) throws XGBoostError {
|
||||||
return new JavaBoosterImpl(params, dMatrixs);
|
return new Booster(params, dMatrixs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -127,7 +127,7 @@ public class XGBoost {
|
|||||||
*/
|
*/
|
||||||
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
|
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
|
||||||
throws XGBoostError {
|
throws XGBoostError {
|
||||||
return new JavaBoosterImpl(params, modelPath);
|
return new Booster(params, modelPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -16,172 +16,86 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import java.io.IOException
|
import ml.dmlc.xgboost4j.java
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
trait Booster extends Serializable {
|
class Booster private[xgboost4j](booster: java.Booster) extends Serializable {
|
||||||
|
|
||||||
|
def setParam(key: String, value: String): Unit = {
|
||||||
|
booster.setParam(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
def update(dtrain: DMatrix, iter: Int): Unit = {
|
||||||
|
booster.update(dtrain.jDMatrix, iter)
|
||||||
|
}
|
||||||
|
|
||||||
|
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||||
|
booster.update(dtrain.jDMatrix, obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
def dumpModel(modelPath: String, withStats: Boolean): Unit = {
|
||||||
|
booster.dumpModel(modelPath, withStats)
|
||||||
|
}
|
||||||
|
|
||||||
|
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
|
||||||
|
booster.dumpModel(modelPath, featureMap, withStats)
|
||||||
|
}
|
||||||
|
|
||||||
|
def setParams(params: Map[String, AnyRef]): Unit = {
|
||||||
|
booster.setParams(params.asJava)
|
||||||
|
}
|
||||||
|
|
||||||
|
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
|
||||||
|
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||||
|
}
|
||||||
|
|
||||||
|
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
|
||||||
|
String = {
|
||||||
|
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||||
|
}
|
||||||
|
|
||||||
|
def dispose: Unit = {
|
||||||
|
booster.dispose()
|
||||||
|
}
|
||||||
|
|
||||||
|
def predict(data: DMatrix): Array[Array[Float]] = {
|
||||||
|
booster.predict(data.jDMatrix)
|
||||||
|
}
|
||||||
|
|
||||||
|
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
|
||||||
|
booster.predict(data.jDMatrix, outPutMargin)
|
||||||
|
}
|
||||||
|
|
||||||
|
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
||||||
|
Array[Array[Float]] = {
|
||||||
|
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
|
||||||
|
booster.predict(data.jDMatrix, treeLimit, predLeaf)
|
||||||
|
}
|
||||||
|
|
||||||
|
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||||
|
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||||
|
}
|
||||||
|
|
||||||
|
def getFeatureScore: mutable.Map[String, Integer] = {
|
||||||
|
booster.getFeatureScore.asScala
|
||||||
|
}
|
||||||
|
|
||||||
|
def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
|
||||||
|
booster.getFeatureScore(featureMap).asScala
|
||||||
|
}
|
||||||
|
|
||||||
|
def saveModel(modelPath: String): Unit = {
|
||||||
|
booster.saveModel(modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def finalize(): Unit = {
|
||||||
|
super.finalize()
|
||||||
|
dispose
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* set parameter
|
|
||||||
*
|
|
||||||
* @param key param name
|
|
||||||
* @param value param value
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def setParam(key: String, value: String)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* set parameters
|
|
||||||
*
|
|
||||||
* @param params parameters key-value map
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def setParams(params: Map[String, AnyRef])
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Update (one iteration)
|
|
||||||
*
|
|
||||||
* @param dtrain training data
|
|
||||||
* @param iter current iteration number
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def update(dtrain: DMatrix, iter: Int)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* update with customize obj func
|
|
||||||
*
|
|
||||||
* @param dtrain training data
|
|
||||||
* @param obj customized objective class
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def update(dtrain: DMatrix, obj: ObjectiveTrait)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* update with give grad and hess
|
|
||||||
*
|
|
||||||
* @param dtrain training data
|
|
||||||
* @param grad first order of gradient
|
|
||||||
* @param hess seconde order of gradient
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
|
|
||||||
|
|
||||||
/**
|
|
||||||
* evaluate with given dmatrixs.
|
|
||||||
*
|
|
||||||
* @param evalMatrixs dmatrixs for evaluation
|
|
||||||
* @param evalNames name for eval dmatrixs, used for check results
|
|
||||||
* @param iter current eval iteration
|
|
||||||
* @return eval information
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
|
|
||||||
|
|
||||||
/**
|
|
||||||
* evaluate with given customized Evaluation class
|
|
||||||
*
|
|
||||||
* @param evalMatrixs evaluation matrix
|
|
||||||
* @param evalNames evaluation names
|
|
||||||
* @param eval custom evaluator
|
|
||||||
* @return eval information
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @return predict result
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def predict(data: DMatrix): Array[Array[Float]]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
|
||||||
* @return predict result
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
|
||||||
* @return predict result
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Predict with data
|
|
||||||
*
|
|
||||||
* @param data dmatrix storing the input
|
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
|
||||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
|
||||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
|
||||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
|
||||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
|
||||||
* @return predict result
|
|
||||||
* @throws XGBoostError native error
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* save model to modelPath
|
|
||||||
*
|
|
||||||
* @param modelPath model path
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def saveModel(modelPath: String)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dump model into a text file.
|
|
||||||
*
|
|
||||||
* @param modelPath file to save dumped model info
|
|
||||||
* @param withStats bool Controls whether the split statistics are output.
|
|
||||||
*/
|
|
||||||
@throws(classOf[IOException])
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def dumpModel(modelPath: String, withStats: Boolean)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Dump model into a text file.
|
|
||||||
*
|
|
||||||
* @param modelPath file to save dumped model info
|
|
||||||
* @param featureMap featureMap file
|
|
||||||
* @param withStats bool
|
|
||||||
* Controls whether the split statistics are output.
|
|
||||||
*/
|
|
||||||
@throws(classOf[IOException])
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* get importance of each feature
|
|
||||||
*
|
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
|
||||||
*/
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def getFeatureScore: mutable.Map[String, Integer]
|
|
||||||
|
|
||||||
/**
 * Get the importance of each feature, using a featureMap file.
 *
 * @param featureMap featureMap file
 * @return map whose key is the feature index and whose value is the feature importance score
 * @throws XGBoostError native error
 */
|
|
||||||
@throws(classOf[XGBoostError])
|
|
||||||
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
|
|
||||||
|
|
||||||
/**
 * Dispose this booster.
 * NOTE(review): presumably releases the resources held by the underlying booster -- confirm.
 */
def dispose: Unit
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,99 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
import scala.collection.mutable
|
|
||||||
|
|
||||||
/**
 * [[Booster]] implementation that forwards every call to a wrapped `java.Booster`.
 *
 * Scala `DMatrix` arguments are handed over as their `jDMatrix`, Scala maps are converted
 * with `asJava`, and Java maps coming back are exposed with `asScala`.
 *
 * @param booster the Java booster every member delegates to
 */
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {

  // ---- parameters ----

  override def setParam(key: String, value: String): Unit = booster.setParam(key, value)

  override def setParams(params: Map[String, AnyRef]): Unit = booster.setParams(params.asJava)

  // ---- training ----

  override def update(dtrain: DMatrix, iter: Int): Unit = booster.update(dtrain.jDMatrix, iter)

  override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit =
    booster.update(dtrain.jDMatrix, obj)

  override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit =
    booster.boost(dtrain.jDMatrix, grad, hess)

  // ---- evaluation ----

  override def evalSet(
      evalMatrixs: Array[DMatrix],
      evalNames: Array[String],
      iter: Int): String = {
    val javaMatrices = evalMatrixs.map(matrix => matrix.jDMatrix)
    booster.evalSet(javaMatrices, evalNames, iter)
  }

  override def evalSet(
      evalMatrixs: Array[DMatrix],
      evalNames: Array[String],
      eval: EvalTrait): String = {
    val javaMatrices = evalMatrixs.map(matrix => matrix.jDMatrix)
    booster.evalSet(javaMatrices, evalNames, eval)
  }

  // ---- prediction ----

  override def predict(data: DMatrix): Array[Array[Float]] = booster.predict(data.jDMatrix)

  override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] =
    booster.predict(data.jDMatrix, outPutMargin)

  override def predict(
      data: DMatrix,
      outPutMargin: Boolean,
      treeLimit: Int): Array[Array[Float]] =
    booster.predict(data.jDMatrix, outPutMargin, treeLimit)

  override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] =
    booster.predict(data.jDMatrix, treeLimit, predLeaf)

  // ---- model persistence and inspection ----

  override def saveModel(modelPath: String): Unit = booster.saveModel(modelPath)

  override def dumpModel(modelPath: String, withStats: Boolean): Unit =
    booster.dumpModel(modelPath, withStats)

  override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit =
    booster.dumpModel(modelPath, featureMap, withStats)

  override def getFeatureScore: mutable.Map[String, Integer] =
    booster.getFeatureScore.asScala

  override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] =
    booster.getFeatureScore(featureMap).asScala

  // ---- lifecycle ----

  override def dispose: Unit = booster.dispose()

  /** Runs the inherited finalizer first, then disposes the wrapped booster. */
  override def finalize(): Unit = {
    super.finalize()
    dispose
  }
}
|
|
||||||
@ -16,9 +16,10 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
||||||
|
|
||||||
object XGBoost {
|
object XGBoost {
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -31,7 +32,7 @@ object XGBoost {
|
|||||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
||||||
obj, eval)
|
obj, eval)
|
||||||
new ScalaBoosterImpl(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
def crossValidation(
|
def crossValidation(
|
||||||
@ -47,11 +48,11 @@ object XGBoost {
|
|||||||
|
|
||||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||||
new ScalaBoosterImpl(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
|
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
|
||||||
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
|
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
|
||||||
new ScalaBoosterImpl(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user