example of DistTrainWithSpark and trigger job with foreachPartition
This commit is contained in:
parent
f768edfede
commit
808e30f9fc
@ -25,7 +25,7 @@
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<artifactId>xgboost4j-spark</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
|
||||
@ -0,0 +1,74 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.demo
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.linalg.DenseVector
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||
|
||||
|
||||
/**
 * Demo entry point: reads LibSVM-style text files from the local file system, trains an
 * XGBoost model on a Spark RDD built from the training file, and runs prediction on the
 * test file.
 */
object DistTrainWithSpark {

  // Length of the dense feature array built for every sample. Feature ids found in the
  // input are used directly as array indices, so they must lie in [0, NumFeatures).
  private val NumFeatures = 129

  /**
   * Reads every line of `filePath` and converts it to a LabeledPoint.
   *
   * The underlying file handle is always released, even when a line fails to parse.
   */
  private def readFile(filePath: String): List[LabeledPoint] = {
    val file = Source.fromFile(new File(filePath))
    try {
      // getLines() is lazy; toList forces it while the source is still open.
      file.getLines().map(fromSVMStringToLabeledPoint).toList
    } finally {
      file.close()
    }
  }

  /**
   * Parses one "label id:value id:value ..." line into a dense LabeledPoint.
   *
   * Tokens may be separated by any run of whitespace; features that are not listed
   * keep the default value 0.0.
   */
  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
    val labelAndFeatures = line.trim.split("\\s+")
    val label = labelAndFeatures(0).toInt
    val denseFeature = new Array[Double](NumFeatures)
    for (feature <- labelAndFeatures.tail) {
      val idAndValue = feature.split(":")
      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
    }
    LabeledPoint(label, new DenseVector(denseFeature))
  }

  /**
   * Expected arguments, in order: number of training-set partitions, number of boosting
   * rounds, training file path, test file path.
   */
  def main(args: Array[String]): Unit = {
    // Brings the implicit Spark -> XGBoost labeled point conversions into scope
    // (needed to build the test DMatrix from Spark LabeledPoints).
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
    if (args.length != 4) {
      println(
        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
      sys.exit(1)
    }
    // Parse the numeric arguments up front so a malformed value fails before any
    // Spark resources are allocated.
    val numPartitions = args(0).toInt
    val numRounds = args(1).toInt
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
    val sc = new SparkContext()
    try {
      val trainingLabeledPoints = readFile(inputTrainPath)
      val trainRDD = sc.parallelize(trainingLabeledPoints, numPartitions)
      val testLabeledPoints = readFile(inputTestPath).iterator
      val testMatrix = new DMatrix(testLabeledPoints, null)
      val booster = XGBoost.train(trainRDD,
        List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
          "objective" -> "binary:logistic").toMap, numRounds, null, null)
      // The prediction result is not used further; this only exercises the trained model.
      booster.map(boosterInstance => boosterInstance.predict(testMatrix))
    } finally {
      // Release the Spark context even when reading, training or prediction fails.
      sc.stop()
    }
  }
}
|
||||
@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
|
||||
private[spark] object DataUtils extends Serializable {
|
||||
object DataUtils extends Serializable {
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPointsAsJava(
|
||||
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
|
||||
fromSparkToXGBoostLabeledPoints(sps).asJava
|
||||
}
|
||||
|
||||
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
|
||||
java.util.Iterator[LabeledPoint] = {
|
||||
(for (p <- sps) yield {
|
||||
Iterator[LabeledPoint] = {
|
||||
for (p <- sps) yield {
|
||||
p.features match {
|
||||
case denseFeature: DenseVector =>
|
||||
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
|
||||
@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
|
||||
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
|
||||
sparseFeature.values.map(_.toFloat))
|
||||
}
|
||||
}).asJava
|
||||
}
|
||||
|
||||
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
|
||||
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
|
||||
}
|
||||
|
||||
private def fetchUpdateFromVector(feature: Vector) = feature match {
|
||||
case denseFeature: DenseVector =>
|
||||
fetchUpdateFromSparseVector(denseFeature.toSparse)
|
||||
case sparseFeature: SparseVector =>
|
||||
fetchUpdateFromSparseVector(sparseFeature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,7 +61,8 @@ object XGBoost extends Serializable {
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
|
||||
// force the job
|
||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||
boosters.foreachPartition(_ => ())
|
||||
println("=====finished training=====")
|
||||
val booster = boosters.first()
|
||||
val returnVal = tracker.waitFor()
|
||||
logger.info(s"Rabit returns with exit code $returnVal")
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* Labeled data point for training examples.
|
||||
* Represent a sparse training instance.
|
||||
*/
|
||||
public class LabeledPoint {
|
||||
public class LabeledPoint implements Serializable {
|
||||
/** Label of the point */
|
||||
public float label;
|
||||
/** Weight of this data point */
|
||||
|
||||
@ -1,41 +1,140 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface Booster extends Serializable {
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Booster for xgboost, similar to the python wrapper xgboost.py
|
||||
* but custom obj function and eval function not supported at present.
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class Booster implements Serializable {
|
||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||
|
||||
long handle = 0;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* init Booster from dMatrixs
|
||||
*
|
||||
* @param params parameters
|
||||
* @param dMatrixs DMatrix array
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
Booster(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
|
||||
init(dMatrixs);
|
||||
setParam("seed", "0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
* @param params parameters
|
||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
Booster(Map<String, Object> params, String modelPath) throws XGBoostError {
  // Validate the argument before touching native code so that a null path does not
  // leave behind a freshly created native booster handle.
  if (modelPath == null) {
    throw new NullPointerException("modelPath : null");
  }
  init(null);
  loadModel(modelPath);
  // Default seed first, so an explicit "seed" entry in params can override it.
  setParam("seed", "0");
  setParams(params);
}
|
||||
|
||||
|
||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||
long[] handles = null;
|
||||
if (dMatrixs != null) {
|
||||
handles = dmatrixsToHandles(dMatrixs);
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
||||
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
*
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
void setParam(String key, String value) throws XGBoostError;
|
||||
public final void setParam(String key, String value) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
*
|
||||
* @param params parameters key-value map
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
void setParams(Map<String, Object> params) throws XGBoostError;
|
||||
public void setParams(Map<String, Object> params) throws XGBoostError {
|
||||
if (params != null) {
|
||||
for (Map.Entry<String, Object> entry : params.entrySet()) {
|
||||
setParam(entry.getKey(), entry.getValue().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Update (one iteration)
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
void update(DMatrix dtrain, int iter) throws XGBoostError;
|
||||
public void update(DMatrix dtrain, int iter) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param obj customized objective class
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
void update(DMatrix dtrain, IObjective obj) throws XGBoostError;
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = predict(dtrain, true);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* update with given grad and hess
|
||||
@ -43,8 +142,16 @@ public interface Booster extends Serializable {
|
||||
* @param dtrain training data
|
||||
* @param grad first order of gradient
|
||||
* @param hess second order of gradient
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError;
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
if (grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
hess));
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrixs.
|
||||
@ -53,8 +160,15 @@ public interface Booster extends Serializable {
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError;
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||
String[] evalInfo = new String[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
@ -63,17 +177,64 @@ public interface Booster extends Serializable {
|
||||
* @param evalNames evaluation names
|
||||
* @param eval custom evaluator
|
||||
* @return eval information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError;
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
||||
throws XGBoostError {
|
||||
String evalInfo = "";
|
||||
for (int i = 0; i < evalNames.length; i++) {
|
||||
String evalName = evalNames[i];
|
||||
DMatrix evalMat = evalMatrixs[i];
|
||||
float evalResult = eval.eval(predict(evalMat), evalMat);
|
||||
String evalMetric = eval.getMetric();
|
||||
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
||||
}
|
||||
return evalInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* base function for Predict
|
||||
*
|
||||
* @param data data
|
||||
* @param outPutMargin output margin
|
||||
* @param treeLimit limit number of trees
|
||||
* @param predLeaf whether to output the predicted leaf index of each sample in each tree
|
||||
* @return predict results
|
||||
*/
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
|
||||
boolean predLeaf) throws XGBoostError {
|
||||
int optionMask = 0;
|
||||
if (outPutMargin) {
|
||||
optionMask = 1;
|
||||
}
|
||||
if (predLeaf) {
|
||||
optionMask = 2;
|
||||
}
|
||||
float[][] rawPredicts = new float[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
treeLimit, rawPredicts));
|
||||
int row = (int) data.rowNum();
|
||||
int col = rawPredicts[0].length / row;
|
||||
float[][] predicts = new float[row][col];
|
||||
int r, c;
|
||||
for (int i = 0; i < rawPredicts[0].length; i++) {
|
||||
r = i / col;
|
||||
c = i % col;
|
||||
predicts[r][c] = rawPredicts[0][i];
|
||||
}
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
float[][] predict(DMatrix data) throws XGBoostError;
|
||||
|
||||
public float[][] predict(DMatrix data) throws XGBoostError {
|
||||
return pred(data, false, 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
@ -81,9 +242,11 @@ public interface Booster extends Serializable {
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError;
|
||||
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
|
||||
return pred(data, outPutMargin, 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
@ -92,31 +255,189 @@ public interface Booster extends Serializable {
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError;
|
||||
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
|
||||
return pred(data, outPutMargin, treeLimit, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index
|
||||
* of each sample in each tree.
|
||||
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
* in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
|
||||
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
|
||||
return pred(data, false, treeLimit, predLeaf);
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to modelPath, the model path support depends on the path support
|
||||
* in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
|
||||
* compiled with HDFS support.
|
||||
* See also toByteArray
|
||||
* save model to modelPath
|
||||
*
|
||||
* @param modelPath model path
|
||||
*/
|
||||
void saveModel(String modelPath) throws XGBoostError;
|
||||
public void saveModel(String modelPath) throws XGBoostError{
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||
}
|
||||
|
||||
private void loadModel(String modelPath) {
|
||||
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
*
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
*
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
* @throws FileNotFoundException file not found
|
||||
* @throws UnsupportedEncodingException unsupported feature
|
||||
* @throws IOException error with model writing
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
String[] modelInfos = getDumpInfo(withStats);
|
||||
|
||||
for (int i = 0; i < modelInfos.length; i++) {
|
||||
writer.write("booster [" + i + "]:\n");
|
||||
writer.write(modelInfos[i]);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
* @throws FileNotFoundException exception
|
||||
* @throws UnsupportedEncodingException exception
|
||||
* @throws IOException exception
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
|
||||
IOException, XGBoostError {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
String[] modelInfos = getDumpInfo(featureMap, withStats);
|
||||
|
||||
for (int i = 0; i < modelInfos.length; i++) {
|
||||
writer.write("booster [" + i + "]:\n");
|
||||
writer.write(modelInfos[i]);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore() throws XGBoostError {
|
||||
String[] modelInfos = getDumpInfo(false);
|
||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if (array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if (featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
} else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if (array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if (featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
} else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
@ -127,41 +448,77 @@ public interface Booster extends Serializable {
|
||||
* @return the saved byte array.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
byte[] toByteArray() throws XGBoostError;
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
byte[][] bytes = new byte[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool Controls whether the split statistics are output.
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError;
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
* Save the booster model into thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void dumpModel(String modelPath, String featureMap, boolean withStats)
|
||||
throws IOException, XGBoostError;
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
Map<String, Integer> getFeatureScore() throws XGBoostError ;
|
||||
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for (int i = 0; i < dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError;
|
||||
// making Booster serializable
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
}
|
||||
|
||||
void dispose();
|
||||
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
|
||||
try {
|
||||
this.init(null);
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
super.finalize();
|
||||
dispose();
|
||||
}
|
||||
|
||||
public synchronized void dispose() {
|
||||
if (handle != 0L) {
|
||||
XGBoostJNI.XGBoosterFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Iterator;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
@ -56,7 +57,7 @@ class DataBatch {
|
||||
return b;
|
||||
}
|
||||
|
||||
static class BatchIterator implements Iterator<DataBatch> {
|
||||
static class BatchIterator implements Iterator<DataBatch>, Serializable {
|
||||
private Iterator<LabeledPoint> base;
|
||||
private int batchSize;
|
||||
|
||||
|
||||
@ -1,525 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Booster for xgboost, similar to the python wrapper xgboost.py
|
||||
* but custom obj function and eval function not supported at present.
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
class JavaBoosterImpl implements Booster {
|
||||
private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class);
|
||||
|
||||
long handle = 0;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* init Booster from dMatrixs
|
||||
*
|
||||
* @param params parameters
|
||||
* @param dMatrixs DMatrix array
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
JavaBoosterImpl(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
|
||||
init(dMatrixs);
|
||||
setParam("seed", "0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
* @param params parameters
|
||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
JavaBoosterImpl(Map<String, Object> params, String modelPath) throws XGBoostError {
|
||||
init(null);
|
||||
if (modelPath == null) {
|
||||
throw new NullPointerException("modelPath : null");
|
||||
}
|
||||
loadModel(modelPath);
|
||||
setParam("seed", "0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
|
||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||
long[] handles = null;
|
||||
if (dMatrixs != null) {
|
||||
handles = dmatrixsToHandles(dMatrixs);
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
||||
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
*
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public final void setParam(String key, String value) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
*
|
||||
* @param params parameters key-value map
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setParams(Map<String, Object> params) throws XGBoostError {
|
||||
if (params != null) {
|
||||
for (Map.Entry<String, Object> entry : params.entrySet()) {
|
||||
setParam(entry.getKey(), entry.getValue().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Update (one iteration)
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param obj customized objective class
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = predict(dtrain, true);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
if (grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
hess));
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrixs.
|
||||
*
|
||||
* @param evalMatrixs dmatrixs for evaluation
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||
String[] evalInfo = new String[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
*
|
||||
* @param evalMatrixs evaluation matrix
|
||||
* @param evalNames evaluation names
|
||||
* @param eval custom evaluator
|
||||
* @return eval information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
|
||||
throws XGBoostError {
|
||||
String evalInfo = "";
|
||||
for (int i = 0; i < evalNames.length; i++) {
|
||||
String evalName = evalNames[i];
|
||||
DMatrix evalMat = evalMatrixs[i];
|
||||
float evalResult = eval.eval(predict(evalMat), evalMat);
|
||||
String evalMetric = eval.getMetric();
|
||||
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
|
||||
}
|
||||
return evalInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* base function for Predict
|
||||
*
|
||||
* @param data data
|
||||
* @param outPutMargin output margin
|
||||
* @param treeLimit limit number of trees
|
||||
* @param predLeaf prediction minimum to keep leafs
|
||||
* @return predict results
|
||||
*/
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
|
||||
boolean predLeaf) throws XGBoostError {
|
||||
int optionMask = 0;
|
||||
if (outPutMargin) {
|
||||
optionMask = 1;
|
||||
}
|
||||
if (predLeaf) {
|
||||
optionMask = 2;
|
||||
}
|
||||
float[][] rawPredicts = new float[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
treeLimit, rawPredicts));
|
||||
int row = (int) data.rowNum();
|
||||
int col = rawPredicts[0].length / row;
|
||||
float[][] predicts = new float[row][col];
|
||||
int r, c;
|
||||
for (int i = 0; i < rawPredicts[0].length; i++) {
|
||||
r = i / col;
|
||||
c = i % col;
|
||||
predicts[r][c] = rawPredicts[0][i];
|
||||
}
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public float[][] predict(DMatrix data) throws XGBoostError {
|
||||
return pred(data, false, 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
|
||||
return pred(data, outPutMargin, 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
|
||||
return pred(data, outPutMargin, treeLimit, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index
|
||||
* of each sample in each tree.
|
||||
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
* in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
|
||||
return pred(data, false, treeLimit, predLeaf);
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
*
|
||||
* @param modelPath model path
|
||||
*/
|
||||
public void saveModel(String modelPath) throws XGBoostError{
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||
}
|
||||
|
||||
private void loadModel(String modelPath) {
|
||||
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
*
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
*
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
* @throws FileNotFoundException file not found
|
||||
* @throws UnsupportedEncodingException unsupported feature
|
||||
* @throws IOException error with model writing
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
String[] modelInfos = getDumpInfo(withStats);
|
||||
|
||||
for (int i = 0; i < modelInfos.length; i++) {
|
||||
writer.write("booster [" + i + "]:\n");
|
||||
writer.write(modelInfos[i]);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
* @throws FileNotFoundException exception
|
||||
* @throws UnsupportedEncodingException exception
|
||||
* @throws IOException exception
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
|
||||
IOException, XGBoostError {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
String[] modelInfos = getDumpInfo(featureMap, withStats);
|
||||
|
||||
for (int i = 0; i < modelInfos.length; i++) {
|
||||
writer.write("booster [" + i + "]:\n");
|
||||
writer.write(modelInfos[i]);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore() throws XGBoostError {
|
||||
String[] modelInfos = getDumpInfo(false);
|
||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if (array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if (featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
} else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if (array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if (featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
} else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as byte array representation.
|
||||
* Write these bytes to a file will give compatible format with other xgboost bindings.
|
||||
*
|
||||
* If java natively support HDFS file API, use toByteArray and write the ByteArray,
|
||||
*
|
||||
* @return the saved byte array.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
byte[][] bytes = new byte[1][];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
*
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for (int i = 0; i < dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
// making Booster serializable
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
}
|
||||
|
||||
private void readObject(java.io.ObjectInputStream in)
|
||||
throws IOException, ClassNotFoundException {
|
||||
try {
|
||||
this.init(null);
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
super.finalize();
|
||||
dispose();
|
||||
}
|
||||
|
||||
public synchronized void dispose() {
|
||||
if (handle != 0L) {
|
||||
XGBoostJNI.XGBoosterFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
@ -9,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
/**
|
||||
* Rabit global class for synchronization.
|
||||
*/
|
||||
public class Rabit {
|
||||
public class Rabit implements Serializable {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
//load native library
|
||||
static {
|
||||
|
||||
@ -71,7 +71,7 @@ public class XGBoost {
|
||||
}
|
||||
|
||||
//initialize booster
|
||||
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
|
||||
Booster booster = new Booster(params, allMats);
|
||||
|
||||
int version = booster.loadRabitCheckpoint();
|
||||
|
||||
@ -115,7 +115,7 @@ public class XGBoost {
|
||||
public static Booster initBoostingModel(
|
||||
Map<String, Object> params,
|
||||
DMatrix[] dMatrixs) throws XGBoostError {
|
||||
return new JavaBoosterImpl(params, dMatrixs);
|
||||
return new Booster(params, dMatrixs);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -127,7 +127,7 @@ public class XGBoost {
|
||||
*/
|
||||
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
|
||||
throws XGBoostError {
|
||||
return new JavaBoosterImpl(params, modelPath);
|
||||
return new Booster(params, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -16,172 +16,86 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.java
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
trait Booster extends Serializable {
|
||||
class Booster private[xgboost4j](booster: java.Booster) extends Serializable {
|
||||
|
||||
/**
 * Sets a single parameter on the wrapped Java booster.
 *
 * @param key   param name
 * @param value param value
 */
def setParam(key: String, value: String): Unit =
  booster.setParam(key, value)
|
||||
|
||||
/**
 * Update (one iteration), delegated to the wrapped Java booster.
 *
 * @param dtrain training data
 * @param iter   current iteration number
 */
def update(dtrain: DMatrix, iter: Int): Unit = {
  val jTrain = dtrain.jDMatrix
  booster.update(jTrain, iter)
}
|
||||
|
||||
/**
 * Update with a customized objective function, delegated to the wrapped Java booster.
 *
 * @param dtrain training data
 * @param obj    customized objective class
 */
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
  val jTrain = dtrain.jDMatrix
  booster.update(jTrain, obj)
}
|
||||
|
||||
/**
 * Dump model into a text file.
 *
 * @param modelPath file to save dumped model info
 * @param withStats controls whether the split statistics are output
 */
def dumpModel(modelPath: String, withStats: Boolean): Unit =
  booster.dumpModel(modelPath, withStats)
|
||||
|
||||
/**
 * Dump model into a text file.
 *
 * @param modelPath  file to save dumped model info
 * @param featureMap featureMap file
 * @param withStats  controls whether the split statistics are output
 */
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit =
  booster.dumpModel(modelPath, featureMap, withStats)
|
||||
|
||||
/**
 * Sets parameters on the wrapped Java booster.
 *
 * @param params parameters key-value map, handed over as a Java map view
 */
def setParams(params: Map[String, AnyRef]): Unit = {
  val jParams = params.asJava
  booster.setParams(jParams)
}
|
||||
|
||||
/**
 * Evaluate with given dmatrixs.
 *
 * @param evalMatrixs dmatrixs for evaluation
 * @param evalNames   name for eval dmatrixs, used for check results
 * @param iter        current eval iteration
 * @return eval information
 */
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
  val jMatrices = evalMatrixs.map(_.jDMatrix)
  booster.evalSet(jMatrices, evalNames, iter)
}
|
||||
|
||||
/**
 * Evaluate with given customized Evaluation class.
 *
 * @param evalMatrixs evaluation matrix
 * @param evalNames   evaluation names
 * @param eval        custom evaluator
 * @return eval information
 */
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
    String = {
  val jMatrices = evalMatrixs.map(_.jDMatrix)
  booster.evalSet(jMatrices, evalNames, eval)
}
|
||||
|
||||
/** Disposes the wrapped Java booster. */
def dispose: Unit =
  booster.dispose()
|
||||
|
||||
/**
 * Predict with data.
 *
 * @param data dmatrix storing the input
 * @return predict result
 */
def predict(data: DMatrix): Array[Array[Float]] =
  booster.predict(data.jDMatrix)
|
||||
|
||||
/**
 * Predict with data.
 *
 * @param data         dmatrix storing the input
 * @param outPutMargin whether to output the raw untransformed margin value
 * @return predict result
 */
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] =
  booster.predict(data.jDMatrix, outPutMargin)
|
||||
|
||||
/**
 * Predict with data.
 *
 * @param data         dmatrix storing the input
 * @param outPutMargin whether to output the raw untransformed margin value
 * @param treeLimit    limit number of trees in the prediction; 0 means use all trees
 * @return predict result
 */
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
    Array[Array[Float]] = {
  val jData = data.jDMatrix
  booster.predict(jData, outPutMargin, treeLimit)
}
|
||||
|
||||
/**
 * Predict with data.
 *
 * @param data      dmatrix storing the input
 * @param treeLimit limit number of trees in the prediction; 0 means use all trees
 * @param predLeaf  when on, the output is a matrix of (nsample, ntrees) with each record
 *                  indicating the predicted leaf index of each sample in each tree
 * @return predict result
 */
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
  val jData = data.jDMatrix
  booster.predict(jData, treeLimit, predLeaf)
}
|
||||
|
||||
/**
 * Update with given grad and hess.
 *
 * @param dtrain training data
 * @param grad   first order of gradient
 * @param hess   second order of gradient
 */
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
  val jTrain = dtrain.jDMatrix
  booster.boost(jTrain, grad, hess)
}
|
||||
|
||||
/**
 * Get importance of each feature.
 *
 * @return featureMap key: feature index, value: feature importance score
 */
def getFeatureScore: mutable.Map[String, Integer] = {
  val jScores = booster.getFeatureScore
  jScores.asScala
}
|
||||
|
||||
/**
 * Get importance of each feature.
 *
 * @param featureMap feature map file passed to the model dump
 * @return featureMap key: feature index, value: feature importance score
 */
def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
  val jScores = booster.getFeatureScore(featureMap)
  jScores.asScala
}
|
||||
|
||||
/**
 * Save model to modelPath.
 *
 * @param modelPath model path
 */
def saveModel(modelPath: String): Unit =
  booster.saveModel(modelPath)
|
||||
|
||||
/**
 * Finalizer: runs the superclass finalizer, then disposes the wrapped Java booster.
 */
override def finalize(): Unit = {
  super.finalize()
  // Releases the wrapped booster via dispose.
  dispose
}
|
||||
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
*
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setParam(key: String, value: String)
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
*
|
||||
* @param params parameters key-value map
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setParams(params: Map[String, AnyRef])
|
||||
|
||||
/**
|
||||
* Update (one iteration)
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, iter: Int)
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait)
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrixs.
|
||||
*
|
||||
* @param evalMatrixs dmatrixs for evaluation
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
*
|
||||
* @param evalMatrixs evaluation matrix
|
||||
* @param evalNames evaluation names
|
||||
* @param eval custom evaluator
|
||||
* @return eval information
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return predict result
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
*
|
||||
* @param modelPath model path
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def saveModel(modelPath: String)
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool Controls whether the split statistics are output.
|
||||
*/
|
||||
@throws(classOf[IOException])
|
||||
@throws(classOf[XGBoostError])
|
||||
def dumpModel(modelPath: String, withStats: Boolean)
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
*
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
*/
|
||||
@throws(classOf[IOException])
|
||||
@throws(classOf[XGBoostError])
|
||||
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore: mutable.Map[String, Integer]
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
*
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
|
||||
|
||||
def dispose
|
||||
}
|
||||
|
||||
@ -1,99 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {
|
||||
|
||||
override def setParam(key: String, value: String): Unit = {
|
||||
booster.setParam(key, value)
|
||||
}
|
||||
|
||||
override def update(dtrain: DMatrix, iter: Int): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
|
||||
booster.dumpModel(modelPath, withStats)
|
||||
}
|
||||
|
||||
override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
|
||||
booster.dumpModel(modelPath, featureMap, withStats)
|
||||
}
|
||||
|
||||
override def setParams(params: Map[String, AnyRef]): Unit = {
|
||||
booster.setParams(params.asJava)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
|
||||
String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||
}
|
||||
|
||||
override def dispose: Unit = {
|
||||
booster.dispose()
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
||||
Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, treeLimit, predLeaf)
|
||||
}
|
||||
|
||||
override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
override def getFeatureScore: mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore.asScala
|
||||
}
|
||||
|
||||
override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore(featureMap).asScala
|
||||
}
|
||||
|
||||
override def saveModel(modelPath: String): Unit = {
|
||||
booster.saveModel(modelPath)
|
||||
}
|
||||
|
||||
override def finalize(): Unit = {
|
||||
super.finalize()
|
||||
dispose
|
||||
}
|
||||
}
|
||||
@ -16,9 +16,10 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
||||
|
||||
object XGBoost {
|
||||
|
||||
def train(
|
||||
@ -31,7 +32,7 @@ object XGBoost {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValidation(
|
||||
@ -47,11 +48,11 @@ object XGBoost {
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
|
||||
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user