add java wrapper
This commit is contained in:
438
java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java
Normal file
438
java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java
Normal file
@@ -0,0 +1,438 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStreamWriter;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
|
||||
/**
|
||||
* Booster for xgboost, similar to the python wrapper xgboost.py
|
||||
* but custom obj function and eval function not supported at present.
|
||||
* @author hzx
|
||||
*/
|
||||
public final class Booster {
|
||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||
|
||||
long handle = 0;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
Initializer.InitXgboost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* init Booster from dMatrixs
|
||||
* @param params parameters
|
||||
* @param dMatrixs DMatrix array
|
||||
*/
|
||||
public Booster(Params params, DMatrix[] dMatrixs) {
|
||||
init(dMatrixs);
|
||||
setParam("seed","0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
* @param params parameters
|
||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||
*/
|
||||
public Booster(Params params, String modelPath) {
|
||||
handle = XgboostJNI.XGBoosterCreate(new long[] {});
|
||||
loadModel(modelPath);
|
||||
setParam("seed","0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
private void init(DMatrix[] dMatrixs) {
|
||||
long[] handles = null;
|
||||
if(dMatrixs != null) {
|
||||
handles = TransferUtil.dMatrixs2handles(dMatrixs);
|
||||
}
|
||||
handle = XgboostJNI.XGBoosterCreate(handles);
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
*/
|
||||
public final void setParam(String key, String value) {
|
||||
XgboostJNI.XGBoosterSetParam(handle, key, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
* @param params parameters key-value map
|
||||
*/
|
||||
public void setParams(Params params) {
|
||||
if(params!=null) {
|
||||
for(Map.Entry<String, String> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Update (one iteration)
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter) {
|
||||
XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter, IObjective obj) {
|
||||
float[][] predicts = predict(dtrain, true);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
* @param dtrain training data
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
*/
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) {
|
||||
if(grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
|
||||
}
|
||||
XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess);
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrixs.
|
||||
* @param evalMatrixs dmatrixs for evaluation
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
||||
long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs);
|
||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
||||
return evalInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given customized Evaluation class
|
||||
* @param evalMatrixs
|
||||
* @param evalNames
|
||||
* @param iter
|
||||
* @param eval
|
||||
* @return eval information
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) {
|
||||
String evalInfo = "";
|
||||
for(int i=0; i<evalNames.length; i++) {
|
||||
String evalName = evalNames[i];
|
||||
DMatrix evalMat = evalMatrixs[i];
|
||||
float evalResult = eval.eval(predict(evalMat), evalMat);
|
||||
String evalMetric = eval.getMetric();
|
||||
evalInfo += String.format("\t%s-%s:%f", evalName,evalMetric, evalResult);
|
||||
}
|
||||
return evalInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrix handles;
|
||||
* @param dHandles evaluation data handles
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
*/
|
||||
public String evalSet(long[] dHandles, String[] evalNames, int iter) {
|
||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames);
|
||||
return evalInfo;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* evaluate with given dmatrix, similar to evalSet
|
||||
* @param evalMat
|
||||
* @param evalName
|
||||
* @param iter
|
||||
* @return eval information
|
||||
*/
|
||||
public String eval(DMatrix evalMat, String evalName, int iter) {
|
||||
DMatrix[] evalMats = new DMatrix[] {evalMat};
|
||||
String[] evalNames = new String[] {evalName};
|
||||
return evalSet(evalMats, evalNames, iter);
|
||||
}
|
||||
|
||||
/**
|
||||
* base function for Predict
|
||||
* @param data
|
||||
* @param outPutMargin
|
||||
* @param treeLimit
|
||||
* @param predLeaf
|
||||
* @return predict results
|
||||
*/
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) {
|
||||
int optionMask = 0;
|
||||
if(outPutMargin) {
|
||||
optionMask = 1;
|
||||
}
|
||||
if(predLeaf) {
|
||||
optionMask = 2;
|
||||
}
|
||||
float[] rawPredicts = XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit);
|
||||
int row = (int) data.rowNum();
|
||||
int col = (int) rawPredicts.length/row;
|
||||
float[][] predicts = new float[row][col];
|
||||
int r,c;
|
||||
for(int i=0; i< rawPredicts.length; i++) {
|
||||
r = i/col;
|
||||
c = i%col;
|
||||
predicts[r][c] = rawPredicts[i];
|
||||
}
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
*/
|
||||
public float[][] predict(DMatrix data) {
|
||||
return pred(data, false, 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin) {
|
||||
return pred(data, outPutMargin, 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return predict result
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) {
|
||||
return pred(data, outPutMargin, treeLimit, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
||||
with each record indicating the predicted leaf index of each sample in each tree.
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
*/
|
||||
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) {
|
||||
return pred(data, false, treeLimit, predLeaf);
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
* @param modelPath
|
||||
*/
|
||||
public void saveModel(String modelPath) {
|
||||
XgboostJNI.XGBoosterSaveModel(handle, modelPath);
|
||||
}
|
||||
|
||||
private void loadModel(String modelPath) {
|
||||
XgboostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
*/
|
||||
public String[] getDumpInfo(boolean withStats) {
|
||||
int statsFlag = 0;
|
||||
if(withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag);
|
||||
return modelInfos;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the dump of the model as a string array
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
*/
|
||||
public String[] getDumpInfo(String featureMap, boolean withStats) {
|
||||
int statsFlag = 0;
|
||||
if(withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag);
|
||||
return modelInfos;
|
||||
}
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param withStats bool
|
||||
Controls whether the split statistics are output.
|
||||
* @throws FileNotFoundException
|
||||
* @throws UnsupportedEncodingException
|
||||
* @throws IOException
|
||||
*/
|
||||
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
String[] modelInfos = getDumpInfo(withStats);
|
||||
|
||||
for(int i=0; i<modelInfos.length; i++) {
|
||||
writer.write("booster [" + i +"]:\n");
|
||||
writer.write(modelInfos[i]);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dump model into a text file.
|
||||
* @param modelPath file to save dumped model info
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats bool
|
||||
Controls whether the split statistics are output.
|
||||
* @throws FileNotFoundException
|
||||
* @throws UnsupportedEncodingException
|
||||
* @throws IOException
|
||||
*/
|
||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
String[] modelInfos = getDumpInfo(featureMap, withStats);
|
||||
|
||||
for(int i=0; i<modelInfos.length; i++) {
|
||||
writer.write("booster [" + i +"]:\n");
|
||||
writer.write(modelInfos[i]);
|
||||
}
|
||||
|
||||
writer.close();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore() {
|
||||
String[] modelInfos = getDumpInfo(false);
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for(String tree : modelInfos) {
|
||||
for(String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if(array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if(featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
}
|
||||
else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) {
|
||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for(String tree : modelInfos) {
|
||||
for(String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if(array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if(featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
}
|
||||
else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if(handle != 0l) {
|
||||
XgboostJNI.XGBoosterFree(handle);
|
||||
handle=0;
|
||||
}
|
||||
}
|
||||
}
|
||||
217
java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java
Normal file
217
java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java
Normal file
@@ -0,0 +1,217 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
/**
|
||||
* DMatrix for xgboost, similar to the python wrapper xgboost.py
|
||||
* @author hzx
|
||||
*/
|
||||
public class DMatrix {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
long handle = 0;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
Initializer.InitXgboost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* sparse matrix type (CSR or CSC)
|
||||
*/
|
||||
public static enum SparseType {
|
||||
CSR,
|
||||
CSC;
|
||||
}
|
||||
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
* @param dataPath
|
||||
*/
|
||||
public DMatrix(String dataPath) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from sparse matrix
|
||||
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
*/
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) {
|
||||
if(st == SparseType.CSR) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data);
|
||||
}
|
||||
else if(st == SparseType.CSC) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data);
|
||||
}
|
||||
else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
*/
|
||||
public DMatrix(float[] data, int nrow, int ncol) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f);
|
||||
}
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
* @param handle
|
||||
*/
|
||||
private DMatrix(long handle) {
|
||||
this.handle = handle;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
* @param labels
|
||||
*/
|
||||
public void setLabel(float[] labels) {
|
||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
|
||||
}
|
||||
|
||||
/**
|
||||
* set weight of each instance
|
||||
* @param weights
|
||||
*/
|
||||
public void setWeight(float[] weights) {
|
||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
* @param baseMargin
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) {
|
||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin);
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
* @param baseMargin
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) {
|
||||
float[] flattenMargin = TransferUtil.flatten(baseMargin);
|
||||
setBaseMargin(flattenMargin);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set group sizes of DMatrix (used for ranking)
|
||||
* @param group
|
||||
*/
|
||||
public void setGroup(int[] group) {
|
||||
XgboostJNI.XGDMatrixSetGroup(handle, group);
|
||||
}
|
||||
|
||||
private float[] getFloatInfo(String field) {
|
||||
float[] infos = XgboostJNI.XGDMatrixGetFloatInfo(handle, field);
|
||||
return infos;
|
||||
}
|
||||
|
||||
private int[] getIntInfo(String field) {
|
||||
int[] infos = XgboostJNI.XGDMatrixGetUIntInfo(handle, field);
|
||||
return infos;
|
||||
}
|
||||
|
||||
/**
|
||||
* get label values
|
||||
* @return label
|
||||
*/
|
||||
public float[] getLabel() {
|
||||
return getFloatInfo("label");
|
||||
}
|
||||
|
||||
/**
|
||||
* get weight of the DMatrix
|
||||
* @return weights
|
||||
*/
|
||||
public float[] getWeight() {
|
||||
return getFloatInfo("weight");
|
||||
}
|
||||
|
||||
/**
|
||||
* get base margin of the DMatrix
|
||||
* @return base margin
|
||||
*/
|
||||
public float[] getBaseMargin() {
|
||||
return getFloatInfo("base_margin");
|
||||
}
|
||||
|
||||
/**
|
||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||
* @param rowIndex
|
||||
* @return sliced new DMatrix
|
||||
*/
|
||||
public DMatrix slice(int[] rowIndex) {
|
||||
long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex);
|
||||
DMatrix sMatrix = new DMatrix(sHandle);
|
||||
return sMatrix;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the row number of DMatrix
|
||||
* @return number of rows
|
||||
*/
|
||||
public long rowNum() {
|
||||
return XgboostJNI.XGDMatrixNumRow(handle);
|
||||
}
|
||||
|
||||
/**
|
||||
* save DMatrix to filePath
|
||||
* @param filePath
|
||||
*/
|
||||
public void saveBinary(String filePath) {
|
||||
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
}
|
||||
|
||||
public long getHandle() {
|
||||
return handle;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
delete();
|
||||
}
|
||||
|
||||
public synchronized void delete() {
|
||||
if(handle != 0) {
|
||||
XgboostJNI.XGDMatrixFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j;
|
||||
|
||||
/**
|
||||
* interface for customized evaluation
|
||||
* @author hzx
|
||||
*/
|
||||
public interface IEvaluation {
|
||||
/**
|
||||
* get evaluate metric
|
||||
* @return evalMetric
|
||||
*/
|
||||
public abstract String getMetric();
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
* @param predicts
|
||||
* @param dmat
|
||||
* @return
|
||||
*/
|
||||
public abstract float eval(float[][] predicts, DMatrix dmat);
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* interface for customize Object function
|
||||
* @author hzx
|
||||
*/
|
||||
public interface IObjective {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
* @param predicts untransformed margin predicts
|
||||
* @param dtrain training data
|
||||
* @return List with two float array, correspond to first order grad and second order grad
|
||||
*/
|
||||
public abstract List<float[]> getGradient(float[][] predicts, DMatrix dtrain);
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IObjective;
|
||||
|
||||
/**
|
||||
* cross validation package for xgb
|
||||
* @author hzx
|
||||
*/
|
||||
public class CVPack {
|
||||
DMatrix dtrain;
|
||||
DMatrix dtest;
|
||||
DMatrix[] dmats;
|
||||
long[] dataArray;
|
||||
String[] names;
|
||||
Booster booster;
|
||||
|
||||
/**
|
||||
* create an cross validation package
|
||||
* @param dtrain train data
|
||||
* @param dtest test data
|
||||
* @param params parameters
|
||||
*/
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
||||
dmats = new DMatrix[] {dtrain, dtest};
|
||||
booster = new Booster(params, dmats);
|
||||
dataArray = TransferUtil.dMatrixs2handles(dmats);
|
||||
names = new String[] {"train", "test"};
|
||||
this.dtrain = dtrain;
|
||||
this.dtest = dtest;
|
||||
}
|
||||
|
||||
/**
|
||||
* update one iteration
|
||||
* @param iter iteration num
|
||||
*/
|
||||
public void update(int iter) {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
|
||||
/**
|
||||
* update one iteration
|
||||
* @param iter iteration num
|
||||
* @param obj customized objective
|
||||
*/
|
||||
public void update(int iter, IObjective obj) {
|
||||
booster.update(dtrain, iter, obj);
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluation
|
||||
* @param iter iteration num
|
||||
* @return
|
||||
*/
|
||||
public String eval(int iter) {
|
||||
return booster.evalSet(dataArray, names, iter);
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluation
|
||||
* @param iter iteration num
|
||||
* @param eval customized eval
|
||||
* @return
|
||||
*/
|
||||
public String eval(int iter, IEvaluation eval) {
|
||||
return booster.evalSet(dmats, names, iter, eval);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* class to load native library
|
||||
* @author hzx
|
||||
*/
|
||||
public class Initializer {
|
||||
private static final Log logger = LogFactory.getLog(Initializer.class);
|
||||
|
||||
static boolean initialized = false;
|
||||
public static final String nativePath = "./lib";
|
||||
public static final String nativeResourcePath = "/lib/";
|
||||
public static final String[] libNames = new String[] {"xgboostjavawrapper"};
|
||||
|
||||
public static synchronized void InitXgboost() throws IOException {
|
||||
if(initialized == false) {
|
||||
for(String libName: libNames) {
|
||||
smartLoad(libName);
|
||||
}
|
||||
initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* load native library, this method will first try to load library from java.library.path, then try to load from library in jar package.
|
||||
* @param libName
|
||||
* @throws IOException
|
||||
*/
|
||||
private static void smartLoad(String libName) throws IOException {
|
||||
addNativeDir(nativePath);
|
||||
try {
|
||||
System.loadLibrary(libName);
|
||||
}
|
||||
catch (UnsatisfiedLinkError e) {
|
||||
try {
|
||||
NativeUtils.loadLibraryFromJar(nativeResourcePath + System.mapLibraryName(libName));
|
||||
}
|
||||
catch (IOException e1) {
|
||||
throw e1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* add libPath to java.library.path, then native library in libPath would be load properly
|
||||
* @param libPath
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void addNativeDir(String libPath) throws IOException {
|
||||
try {
|
||||
Field field = ClassLoader.class.getDeclaredField("usr_paths");
|
||||
field.setAccessible(true);
|
||||
String[] paths = (String[]) field.get(null);
|
||||
for (String path : paths) {
|
||||
if (libPath.equals(path)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
String[] tmp = new String[paths.length+1];
|
||||
System.arraycopy(paths,0,tmp,0,paths.length);
|
||||
tmp[paths.length] = libPath;
|
||||
field.set(null, tmp);
|
||||
} catch (IllegalAccessException e) {
|
||||
logger.error(e.getMessage());
|
||||
throw new IOException("Failed to get permissions to set library path");
|
||||
} catch (NoSuchFieldException e) {
|
||||
logger.error(e.getMessage());
|
||||
throw new IOException("Failed to get field handle to set library path");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
/*
|
||||
* To change this license header, choose License Headers in Project Properties.
|
||||
* To change this template file, choose Tools | Templates
|
||||
* and open the template in the editor.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
|
||||
/**
 * Simple library class for working with JNI (Java Native Interface)
 *
 * @see <a href="http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar">
 *      How to load a native JNI library from a JAR</a>
 *
 * @author Adam Heirnich &lt;adam@adamh.cz&gt;, http://www.adamh.cz
 */
public class NativeUtils {

    /**
     * Private constructor - this class will never be instanced
     */
    private NativeUtils() {
    }

    /**
     * Loads library from current JAR archive
     *
     * The file from JAR is copied into system temporary directory and then loaded. The temporary
     * file is registered for deletion on JVM exit.
     * Method uses String as filename because the pathname is "abstract", not system-dependent.
     *
     * @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
     * @throws IOException If temporary file creation or read/write operation fails
     * @throws FileNotFoundException If source file (param path) does not exist
     * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter
     *         than three characters (restriction of
     *         {@link File#createTempFile(java.lang.String, java.lang.String)}).
     */
    public static void loadLibraryFromJar(String path) throws IOException {

        if (!path.startsWith("/")) {
            throw new IllegalArgumentException("The path has to be absolute (start with '/').");
        }

        // Obtain filename from path
        String[] parts = path.split("/");
        String filename = (parts.length > 1) ? parts[parts.length - 1] : null;

        // Split filename to prefix and suffix (extension); everything after the first dot
        // is kept as the suffix.
        String prefix = "";
        String suffix = null;
        if (filename != null) {
            parts = filename.split("\\.", 2);
            prefix = parts[0];
            suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null;
        }

        // Check if the filename is okay
        if (filename == null || prefix.length() < 3) {
            throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
        }

        File temp = extractToTempFile(path, prefix, suffix);

        // Finally, load the library
        System.load(temp.getAbsolutePath());
    }

    /**
     * Copies the classpath resource at {@code path} into a fresh temporary file.
     * <p>
     * The resource is looked up before the temporary file is created, so a missing resource
     * leaves nothing behind, and both streams are closed whatever happens during the copy.
     *
     * @param path   absolute resource path
     * @param prefix temporary file name prefix
     * @param suffix temporary file name suffix, may be {@code null}
     * @return the temporary file holding a copy of the resource
     * @throws IOException if the resource is missing or the copy fails
     */
    private static File extractToTempFile(String path, String prefix, String suffix) throws IOException {
        // Open and check input stream
        InputStream resource = NativeUtils.class.getResourceAsStream(path);
        if (resource == null) {
            throw new FileNotFoundException("File " + path + " was not found inside JAR.");
        }

        try (InputStream is = resource) {
            // Prepare temporary file
            File temp = File.createTempFile(prefix, suffix);
            temp.deleteOnExit();

            if (!temp.exists()) {
                throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
            }

            // Copy data between source file in JAR and the temporary file
            try (OutputStream os = new FileOutputStream(temp)) {
                byte[] buffer = new byte[1024];
                int readBytes;
                while ((readBytes = is.read(buffer)) != -1) {
                    os.write(buffer, 0, readBytes);
                }
            }
            return temp;
        }
    }
}
|
||||
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.AbstractMap;
|
||||
|
||||
|
||||
/**
 * Insertion-ordered list of parameter key/value pairs.
 * Pairs are appended as given, so the same key may be stored more than once.
 * @author hzx
 */
public class Params implements Iterable<Entry<String, String>> {
    List<Entry<String, String>> params = new ArrayList<>();

    /**
     * Append a param key-value pair to the end of the list.
     * @param key parameter name
     * @param value parameter value
     */
    public void put(String key, String value) {
        Entry<String, String> pair = new AbstractMap.SimpleEntry<>(key, value);
        params.add(pair);
    }

    /**
     * Render every stored pair as a {@code key:value} line, in insertion order.
     * @return one newline-terminated line per pair; an empty string when nothing is stored
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        for (Entry<String, String> pair : params) {
            text.append(pair.getKey())
                .append(":")
                .append(pair.getValue())
                .append("\n");
        }
        return text.toString();
    }

    /**
     * @return an iterator over the stored pairs in insertion order
     */
    @Override
    public Iterator<Entry<String, String>> iterator() {
        return params.iterator();
    }
}
|
||||
@@ -0,0 +1,230 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.IObjective;
|
||||
|
||||
|
||||
/**
|
||||
* trainer for xgboost
|
||||
* @author hzx
|
||||
*/
|
||||
public class Trainer {
|
||||
private static final Log logger = LogFactory.getLog(Trainer.class);
|
||||
|
||||
/**
|
||||
* Train a booster with given parameters.
|
||||
* @param params Booster params.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param evalMats Data to be evaluated (may include dtrain)
|
||||
* @param evalNames name of data (used for evaluation info)
|
||||
* @param obj customized objective (set to null if not used)
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return trained booster
|
||||
*/
|
||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||
DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) {
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
if(evalMats!=null && evalMats.length>0) {
|
||||
allMats = new DMatrix[evalMats.length+1];
|
||||
allMats[0] = dtrain;
|
||||
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
|
||||
}
|
||||
else {
|
||||
allMats = new DMatrix[1];
|
||||
allMats[0] = dtrain;
|
||||
}
|
||||
|
||||
//initialize booster
|
||||
Booster booster = new Booster(params, allMats);
|
||||
|
||||
//used for evaluation
|
||||
long[] dataArray = null;
|
||||
String[] names = null;
|
||||
|
||||
if(dataArray==null || names==null) {
|
||||
//prepare data for evaluation
|
||||
dataArray = TransferUtil.dMatrixs2handles(evalMats);
|
||||
names = evalNames;
|
||||
}
|
||||
|
||||
//begin to train
|
||||
for(int iter=0; iter<round; iter++) {
|
||||
if(obj != null) {
|
||||
booster.update(dtrain, iter, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
|
||||
//evaluation
|
||||
if(evalMats!=null && evalMats.length>0) {
|
||||
String evalInfo;
|
||||
if(eval != null) {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
|
||||
}
|
||||
else {
|
||||
evalInfo = booster.evalSet(dataArray, names, iter);
|
||||
}
|
||||
logger.info(evalInfo);
|
||||
}
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
/**
|
||||
* Cross-validation with given paramaters.
|
||||
* @param params Booster params.
|
||||
* @param data Data to be trained.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param nfold Number of folds in CV.
|
||||
* @param metrics Evaluation metrics to be watched in CV.
|
||||
* @param obj customized objective (set to null if not used)
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return evaluation history
|
||||
*/
|
||||
public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
||||
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
||||
String[] evalHist = new String[round];
|
||||
String[] results = new String[cvPacks.length];
|
||||
for(int i=0; i<round; i++) {
|
||||
for(CVPack cvPack : cvPacks) {
|
||||
if(obj != null) {
|
||||
cvPack.update(i, obj);
|
||||
}
|
||||
else {
|
||||
cvPack.update(i);
|
||||
}
|
||||
}
|
||||
|
||||
for(int j=0; j<cvPacks.length; j++) {
|
||||
if(eval != null) {
|
||||
results[j] = cvPacks[j].eval(i, eval);
|
||||
}
|
||||
else {
|
||||
results[j] = cvPacks[j].eval(i);
|
||||
}
|
||||
}
|
||||
|
||||
evalHist[i] = aggCVResults(results);
|
||||
logger.info(evalHist[i]);
|
||||
}
|
||||
return evalHist;
|
||||
}
|
||||
|
||||
/**
|
||||
* make an n-fold array of CVPack from random indices
|
||||
* @param data original data
|
||||
* @param nfold num of folds
|
||||
* @param params booster parameters
|
||||
* @param evalMetrics Evaluation metrics
|
||||
* @return CV package array
|
||||
*/
|
||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) {
|
||||
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
||||
int step = samples.size()/nfold;
|
||||
int[] testSlice = new int[step];
|
||||
int[] trainSlice = new int[samples.size()-step];
|
||||
int testid, trainid;
|
||||
CVPack[] cvPacks = new CVPack[nfold];
|
||||
for(int i=0; i<nfold; i++) {
|
||||
testid = 0;
|
||||
trainid = 0;
|
||||
for(int j=0; j<samples.size(); j++) {
|
||||
if(j>(i*step) && j<(i*step+step) && testid<step) {
|
||||
testSlice[testid] = samples.get(j);
|
||||
testid++;
|
||||
}
|
||||
else{
|
||||
if(trainid<samples.size()-step) {
|
||||
trainSlice[trainid] = samples.get(j);
|
||||
trainid++;
|
||||
}
|
||||
else {
|
||||
testSlice[testid] = samples.get(j);
|
||||
testid++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DMatrix dtrain = data.slice(trainSlice);
|
||||
DMatrix dtest = data.slice(testSlice);
|
||||
CVPack cvPack = new CVPack(dtrain, dtest, params);
|
||||
//set eval types
|
||||
if(evalMetrics!=null) {
|
||||
for(String type : evalMetrics) {
|
||||
cvPack.booster.setParam("eval_metric", type);
|
||||
}
|
||||
}
|
||||
cvPacks[i] = cvPack;
|
||||
}
|
||||
|
||||
return cvPacks;
|
||||
}
|
||||
|
||||
private static List<Integer> genRandPermutationNums(int start, int end) {
|
||||
List<Integer> samples = new ArrayList<>();
|
||||
for(int i=start; i<end; i++) {
|
||||
samples.add(i);
|
||||
}
|
||||
Collections.shuffle(samples);
|
||||
return samples;
|
||||
}
|
||||
|
||||
/**
|
||||
* Aggregate cross-validation results.
|
||||
* @param results eval info from each data sample
|
||||
* @return cross-validation eval info
|
||||
*/
|
||||
public static String aggCVResults(String[] results) {
|
||||
Map<String, List<Float> > cvMap = new HashMap<>();
|
||||
String aggResult = results[0].split("\t")[0];
|
||||
for(String result : results) {
|
||||
String[] items = result.split("\t");
|
||||
for(int i=1; i<items.length; i++) {
|
||||
String[] tup = items[i].split(":");
|
||||
String key = tup[0];
|
||||
Float value = Float.valueOf(tup[1]);
|
||||
if(!cvMap.containsKey(key)) {
|
||||
cvMap.put(key, new ArrayList<Float>());
|
||||
}
|
||||
cvMap.get(key).add(value);
|
||||
}
|
||||
}
|
||||
|
||||
for(String key : cvMap.keySet()) {
|
||||
float value = 0f;
|
||||
for(Float tvalue : cvMap.get(key)) {
|
||||
value += tvalue;
|
||||
}
|
||||
value /= cvMap.get(key).size();
|
||||
aggResult += String.format("\tcv-%s:%f", key, value);
|
||||
}
|
||||
|
||||
return aggResult;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class TransferUtil {
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
public static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for(int i=0; i<dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
public static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.wrapper;
|
||||
|
||||
/**
 * xgboost JNI wrapper functions for xgboost_wrapper.h.
 * Every method here is declared {@code native}, so the library that implements
 * them has to be loaded before any of them is called.
 * @author hzx
 */
public class XgboostJNI {
    // ---- XGDMatrix* functions: these take or return a DMatrix handle ----
    public static final native long XGDMatrixCreateFromFile(String fname, int silent);
    public static final native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data);
    public static final native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data);
    public static final native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing);
    public static final native long XGDMatrixSliceDMatrix(long handle, int[] idxset);
    public static final native void XGDMatrixFree(long handle);
    public static final native void XGDMatrixSaveBinary(long handle, String fname, int silent);
    public static final native void XGDMatrixSetFloatInfo(long handle, String field, float[] array);
    public static final native void XGDMatrixSetUIntInfo(long handle, String field, int[] array);
    public static final native void XGDMatrixSetGroup(long handle, int[] group);
    public static final native float[] XGDMatrixGetFloatInfo(long handle, String field);
    public static final native int[] XGDMatrixGetUIntInfo(long handle, String filed);
    public static final native long XGDMatrixNumRow(long handle);

    // ---- XGBooster* functions: these take or return a booster handle ----
    public static final native long XGBoosterCreate(long[] handles);
    public static final native void XGBoosterFree(long handle);
    public static final native void XGBoosterSetParam(long handle, String name, String value);
    public static final native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
    public static final native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
    public static final native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames);
    public static final native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit);
    public static final native void XGBoosterLoadModel(long handle, String fname);
    public static final native void XGBoosterSaveModel(long handle, String fname);
    public static final native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
    public static final native String XGBoosterGetModelRaw(long handle);
    public static final native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats);
}
|
||||
1
java/xgboost4j/src/main/resources/lib/README.md
Normal file
1
java/xgboost4j/src/main/resources/lib/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Please put the native library in this package.
|
||||
Reference in New Issue
Block a user