Merge pull request #890 from CodingCat/jvm_package

[WIP] refactor xgboost4j to create jvm-packages
This commit is contained in:
Tianqi Chen 2016-03-01 21:01:01 -08:00
commit 7d9457d72f
69 changed files with 4155 additions and 2956 deletions

9
.gitignore vendored
View File

@ -70,3 +70,12 @@ config.mk
xgboost xgboost
*.data *.data
build_plugin build_plugin
dmlc-core
.idea
recommonmark/
tags
*.iml
*.class
target
*.swp

View File

@ -84,7 +84,7 @@ $(DMLC_CORE)/libdmlc.a:
$(RABIT)/lib/$(LIB_RABIT): $(RABIT)/lib/$(LIB_RABIT):
+ cd $(RABIT); make lib/$(LIB_RABIT); cd $(ROOTDIR) + cd $(RABIT); make lib/$(LIB_RABIT); cd $(ROOTDIR)
java: java/libxgboost4j.so jvm: jvm-packages/lib/libxgboost4j.so
SRC = $(wildcard src/*.cc src/*/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) $(PLUGIN_OBJS) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) $(PLUGIN_OBJS)
@ -120,7 +120,8 @@ lib/libxgboost.dll lib/libxgboost.so: $(ALL_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %a, $^) $(LDFLAGS)
java/libxgboost4j.so: java/xgboost4j_wrapper.cpp $(ALL_DEP) jvm-packages/lib/libxgboost4j.so: jvm-packages/xgboost4j/src/native/xgboost4j.cpp $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(JAVAINCFLAGS) -shared -o $@ $(filter %.cpp %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) $(JAVAINCFLAGS) -shared -o $@ $(filter %.cpp %.o %.a, $^) $(LDFLAGS)
xgboost: $(CLI_OBJ) $(ALL_DEP) xgboost: $(CLI_OBJ) $(ALL_DEP)

View File

@ -1,36 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j-demo</artifactId>
<version>1.0</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>1.1</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.4</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -1,164 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.AbstractMap;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* a simple example of java wrapper for xgboost
* @author hzx
*/
public class BasicWalkThrough {
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
if(fPredicts.length != sPredicts.length) {
return false;
}
for(int i=0; i<fPredicts.length; i++) {
if(!Arrays.equals(fPredicts[i], sPredicts[i])) {
return false;
}
}
return true;
}
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
//note: any Iterable<Entry<String, Object>> object would be used as paramters
//e.g.
// Map<String, Object> paramMap = new HashMap<String, Object>() {
// {
// put("eta", 1.0);
// put("max_depth", 2);
// put("silent", 1);
// put("objective", "binary:logistic");
// }
// };
// Iterable<Entry<String, Object>> param = paramMap.entrySet();
//or
// List<Entry<String, Object>> param = new ArrayList<Entry<String, Object>>() {
// {
// add(new SimpleEntry<String, Object>("eta", 1.0));
// add(new SimpleEntry<String, Object>("max_depth", 2.0));
// add(new SimpleEntry<String, Object>("silent", 1));
// add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
// }
// };
//we use a util class Params to handle parameters as example
Iterable<Entry<String, Object>> param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//specify watchList to set evaluation dmats
//note: any Iterable<Entry<String, DMatrix>> object would be used as watchList
//e.g.
//an entrySet of Map is good
// Map<String, DMatrix> watchMap = new HashMap<>();
// watchMap.put("train", trainMat);
// watchMap.put("test", testMat);
// Iterable<Entry<String, DMatrix>> watchs = watchMap.entrySet();
//we use a List of Entry<String, DMatrix> WatchList as example
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
//set round
int round = 2;
//train a boost model
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
//predict
float[][] predicts = booster.predict(testMat);
//save model to modelPath
File file = new File("./model");
if(!file.exists()) {
file.mkdirs();
}
String modelPath = "./model/xgb.model";
booster.saveModel(modelPath);
//dump model
booster.dumpModel("./model/dump.raw.txt", false);
//dump model with feature map
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
//save dmatrix into binary buffer
testMat.saveBinary("./model/dtest.buffer");
//reload model and data
Booster booster2 = new Booster(param, "./model/xgb.model");
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
float[][] predicts2 = booster2.predict(testMat2);
//check the two predicts
System.out.println(checkPredicts(predicts, predicts2));
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
trainMat2.setLabel(spData.labels);
//specify watchList
List<Entry<String, DMatrix>> watchs2 = new ArrayList<>();
watchs2.add(new SimpleEntry<>("train", trainMat2));
watchs2.add(new SimpleEntry<>("test", testMat2));
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
float[][] predicts3 = booster3.predict(testMat2);
//check predicts
System.out.println(checkPredicts(predicts, predicts3));
}
}

View File

@ -1,67 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* example for start from a initial base prediction
* @author hzx
*/
public class BoostFromPrediction {
public static void main(String[] args) throws XGBoostError {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
Params param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train xgboost for 1 round
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);
float[][] trainPred = booster.predict(trainMat, true);
float[][] testPred = booster.predict(testMat, true);
trainMat.setBaseMargin(trainPred);
testMat.setBaseMargin(testPred);
System.out.println("result of running from initial prediction");
Booster booster2 = Trainer.train(param, trainMat, 1, watchs, null, null);
}
}

View File

@ -1,54 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* an example of cross validation
* @author hzx
*/
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//set params
Params param = new Params() {
{
put("eta", 1.0);
put("max_depth", 3);
put("silent", 1);
put("nthread", 6);
put("objective", "binary:logistic");
put("gamma", 1.0);
put("eval_metric", "error");
}
};
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = Trainer.crossValiation(param, trainMat, round, nfold, metrics, null, null);
}
}

View File

@ -1,175 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* an example user define objective and eval
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* he buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
* @author hzx
*/
public class CustomObjective {
/**
* loglikelihoode loss obj function
*/
public static class LogRegObj implements IObjective {
private static final Log logger = LogFactory.getLog(LogRegObj.class);
/**
* simple sigmoid func
* @param input
* @return
* Note: this func is not concern about numerical stability, only used as example
*/
public float sigmoid(float input) {
float val = (float) (1/(1+Math.exp(-input)));
return val;
}
public float[][] transform(float[][] predicts) {
int nrow = predicts.length;
float[][] transPredicts = new float[nrow][1];
for(int i=0; i<nrow; i++) {
transPredicts[i][0] = sigmoid(predicts[i][0]);
}
return transPredicts;
}
@Override
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
int nrow = predicts.length;
List<float[]> gradients = new ArrayList<>();
float[] labels;
try {
labels = dtrain.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return null;
}
float[] grad = new float[nrow];
float[] hess = new float[nrow];
float[][] transPredicts = transform(predicts);
for(int i=0; i<nrow; i++) {
float predict = transPredicts[i][0];
grad[i] = predict - labels[i];
hess[i] = predict * (1 - predict);
}
gradients.add(grad);
gradients.add(hess);
return gradients;
}
}
/**
* user defined eval function.
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* the buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
*/
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0) {
error++;
}
else if(labels[i]==1f && predicts[i][0]<=0) {
error++;
}
}
return error/labels.length;
}
}
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//load valid mat (svmlight format)
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//set params
//set params
Params param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
}
};
//set round
int round = 2;
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//user define obj and eval
IObjective obj = new LogRegObj();
IEvaluation eval = new EvalError();
//train a booster
System.out.println("begin to train the booster model");
Booster booster = Trainer.train(param, trainMat, round, watchs, obj, eval);
}
}

View File

@ -1,65 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* simple example for using external memory version
* @author hzx
*/
public class ExternalMemory {
public static void main(String[] args) throws XGBoostError {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
//specify parameters
Params param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//performance notice: set nthread to be the number of your real cpu
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
//param.put("nthread", num_real_cpu);
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//set round
int round = 2;
//train a boost model
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
}
}

View File

@ -1,74 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* this is an example of fit generalized linear model in xgboost
* basically, we are using linear model, instead of tree for our boosters
* @author hzx
*/
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model
// alpha is the L1 regularizer
//lambda is the L2 regularizer
//you can also set lambda_bias which is L2 regularizer on the bias term
Params param = new Params() {
{
put("alpha", 0.0001);
put("silent", 1);
put("objective", "binary:logistic");
put("booster", "gblinear");
}
};
//normally, you do not need to set eta (step_size)
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
//there could be affection on convergence with parallelization on certain cases
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
//param.put("eta", "0.5");
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 4;
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
float[][] predicts = booster.predict(testMat);
CustomEval eval = new CustomEval();
System.out.println("error=" + eval.eval(predicts, testMat));
}
}

View File

@ -1,69 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* predict first ntree
* @author hzx
*/
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
Params param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 3;
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
//predict use 1 tree
float[][] predicts1 = booster.predict(testMat, false, 1);
//by default all trees are used to do predict
float[][] predicts2 = booster.predict(testMat);
//use a simple evaluation class to check error result
CustomEval eval = new CustomEval();
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
}
}

View File

@ -1,70 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* predict leaf indices
* @author hzx
*/
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
Params param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 3;
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
//predict using first 2 tree
float[][] leafindex = booster.predict(testMat, 2, true);
for(float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
//predict all trees
leafindex = booster.predict(testMat, 0, true);
for(float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
}
}

View File

@ -1,60 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.util.XGBoostError;
/**
* a util evaluation class for examples
* @author hzx
*/
public class CustomEval implements IEvaluation {
private static final Log logger = LogFactory.getLog(CustomEval.class);
String evalMetric = "custom_error";
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0.5) {
error++;
}
else if(labels[i]==1f && predicts[i][0]<=0.5) {
error++;
}
}
return error/labels.length;
}
}

View File

@ -1,127 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
/**
* util class for loading data
* @author hzx
*/
public class DataLoader {
public static class DenseData {
public float[] labels;
public float[] data;
public int nrow;
public int ncol;
}
public static class CSRSparseData {
public float[] labels;
public float[] data;
public long[] rowHeaders;
public int[] colIndex;
}
public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
DenseData denseData = new DenseData();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
denseData.nrow = 0;
denseData.ncol = -1;
String line;
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
while((line=reader.readLine()) != null) {
String[] items = line.trim().split(",");
if(items.length==0) {
continue;
}
denseData.nrow++;
if(denseData.ncol == -1) {
denseData.ncol = items.length - 1;
}
tlabels.add(Float.valueOf(items[items.length-1]));
for(int i=0; i<items.length-1; i++) {
tdata.add(Float.valueOf(items[i]));
}
}
reader.close();
in.close();
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
return denseData;
}
public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
CSRSparseData spData = new CSRSparseData();
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
List<Long> theaders = new ArrayList<>();
List<Integer> tindex = new ArrayList<>();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
String line;
long rowheader = 0;
theaders.add(rowheader);
while((line=reader.readLine()) != null) {
String[] items = line.trim().split(" ");
if(items.length==0) {
continue;
}
rowheader += items.length - 1;
theaders.add(rowheader);
tlabels.add(Float.valueOf(items[0]));
for(int i=1; i<items.length; i++) {
String[] tup = items[i].split(":");
assert tup.length == 2;
tdata.add(Float.valueOf(tup[1]));
tindex.add(Integer.valueOf(tup[0]));
}
}
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
return spData;
}
}

View File

@ -1,54 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.AbstractMap;
/**
* a util class for handle params
* @author hzx
*/
public class Params implements Iterable<Entry<String, Object>>{
List<Entry<String, Object>> params = new ArrayList<>();
/**
* put param key-value pair
* @param key
* @param value
*/
public void put(String key, Object value) {
params.add(new AbstractMap.SimpleEntry<>(key, value));
}
@Override
public String toString(){
String paramsInfo = "";
for(Entry<String, Object> param : params) {
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
}
return paramsInfo;
}
@Override
public Iterator<Entry<String, Object>> iterator() {
return params.iterator();
}
}

View File

@ -1,35 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>1.1</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<reporting>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.10.3</version>
</plugin>
</plugins>
</reporting>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.2</version>
</dependency>
</dependencies>
</project>

View File

@ -1,484 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.ErrorHandle;
import org.dmlc.xgboost4j.util.XGBoostError;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
* Booster for xgboost, similar to the python wrapper xgboost.py
* but custom obj function and eval function not supported at present.
* @author hzx
*/
public final class Booster {
private static final Log logger = LogFactory.getLog(Booster.class);
long handle = 0;
//load native library
static {
try {
Initializer.InitXgboost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* init Booster from dMatrixs
* @param params parameters
* @param dMatrixs DMatrix array
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) throws XGBoostError {
init(dMatrixs);
setParam("seed","0");
setParams(params);
}
/**
* load model from modelPath
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XGBoostError {
init(null);
if(modelPath == null) {
throw new NullPointerException("modelPath : null");
}
loadModel(modelPath);
setParam("seed","0");
setParams(params);
}
private void init(DMatrix[] dMatrixs) throws XGBoostError {
long[] handles = null;
if(dMatrixs != null) {
handles = dMatrixs2handles(dMatrixs);
}
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* set parameter
* @param key param name
* @param value param value
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public final void setParam(String key, String value) throws XGBoostError {
ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
}
/**
* set parameters
* @param params parameters key-value map
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void setParams(Iterable<Entry<String, Object>> params) throws XGBoostError {
if(params!=null) {
for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
/**
* Update (one iteration)
* @param dtrain training data
* @param iter current iteration number
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void update(DMatrix dtrain, int iter) throws XGBoostError {
ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
* update with customize obj func
* @param dtrain training data
* @param iter current iteration number
* @param obj customized objective class
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* update with give grad and hess
* @param dtrain training data
* @param grad first order of gradient
* @param hess seconde order of gradient
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
if(grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
}
ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
}
/**
* evaluate with given dmatrixs.
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dMatrixs2handles(evalMatrixs);
String[] evalInfo = new String[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
return evalInfo[0];
}
/**
* evaluate with given customized Evaluation class
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param iter number of interations
* @param eval custom evaluator
* @return eval information
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XGBoostError {
String evalInfo = "";
for(int i=0; i<evalNames.length; i++) {
String evalName = evalNames[i];
DMatrix evalMat = evalMatrixs[i];
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName,evalMetric, evalResult);
}
return evalInfo;
}
/**
* evaluate with given dmatrix handles;
* @param dHandles evaluation data handles
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String evalSet(long[] dHandles, String[] evalNames, int iter) throws XGBoostError {
String[] evalInfo = new String[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames, evalInfo));
return evalInfo[0];
}
/**
* evaluate with given dmatrix, similar to evalSet
* @param evalMat evaluation matrix
* @param evalName evaluation name
* @param iter number of iterations
* @return eval information
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String eval(DMatrix evalMat, String evalName, int iter) throws XGBoostError {
DMatrix[] evalMats = new DMatrix[] {evalMat};
String[] evalNames = new String[] {evalName};
return evalSet(evalMats, evalNames, iter);
}
/**
* base function for Predict
* @param data data
* @param outPutMargin output margin
* @param treeLimit limit number of trees
* @param predLeaf prediction minimum to keep leafs
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if(outPutMargin) {
optionMask = 1;
}
if(predLeaf) {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = (int) rawPredicts[0].length/row;
float[][] predicts = new float[row][col];
int r,c;
for(int i=0; i< rawPredicts[0].length; i++) {
r = i/col;
c = i%col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
/**
* Predict with data
* @param data dmatrix storing the input
* @return predict result
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[][] predict(DMatrix data) throws XGBoostError {
return pred(data, false, 0, false);
}
/**
* Predict with data
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
return pred(data, outPutMargin, 0, false);
}
/**
* Predict with data
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
return pred(data, outPutMargin, treeLimit, false);
}
/**
* Predict with data
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
with each record indicating the predicted leaf index of each sample in each tree.
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0.
* @return predict result
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[][] predict(DMatrix data , int treeLimit, boolean predLeaf) throws XGBoostError {
return pred(data, false, treeLimit, predLeaf);
}
/**
* save model to modelPath
* @param modelPath model path
*/
public void saveModel(String modelPath) {
XgboostJNI.XGBoosterSaveModel(handle, modelPath);
}
private void loadModel(String modelPath) {
XgboostJNI.XGBoosterLoadModel(handle, modelPath);
}
/**
* get the dump of the model as a string array
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if(withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
* get the dump of the model as a string array
* @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
int statsFlag = 0;
if(withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
return modelInfos[0];
}
/**
* Dump model into a text file.
* @param modelPath file to save dumped model info
* @param withStats bool
Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported feature
* @throws IOException error with model writing
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(withStats);
for(int i=0; i<modelInfos.length; i++) {
writer.write("booster [" + i +"]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* Dump model into a text file.
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
Controls whether the split statistics are output.
* @throws FileNotFoundException exception
* @throws UnsupportedEncodingException exception
* @throws IOException exception
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(featureMap, withStats);
for(int i=0; i<modelInfos.length; i++) {
writer.write("booster [" + i +"]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* get importance of each feature
* @return featureMap key: feature index, value: feature importance score
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public Map<String, Integer> getFeatureScore() throws XGBoostError {
String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<>();
for(String tree : modelInfos) {
for(String node : tree.split("\n")) {
String[] array = node.split("\\[");
if(array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if(featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
}
else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* get importance of each feature
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<>();
for(String tree : modelInfos) {
for(String node : tree.split("\n")) {
String[] array = node.split("\\[");
if(array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if(featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
}
else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* transfer DMatrix array to handle array (used for native functions)
* @param dmatrixs
* @return handle array for input dmatrixs
*/
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for(int i=0; i<dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
@Override
protected void finalize() {
delete();
}
public synchronized void delete() {
if(handle != 0l) {
XgboostJNI.XGBoosterFree(handle);
handle=0;
}
}
}

View File

@ -1,268 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.ErrorHandle;
import org.dmlc.xgboost4j.util.XGBoostError;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
* DMatrix for xgboost, similar to the python wrapper xgboost.py
* @author hzx
*/
public class DMatrix {
private static final Log logger = LogFactory.getLog(DMatrix.class);
long handle = 0;
//load native library
static {
try {
Initializer.InitXgboost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* sparse matrix type (CSR or CSC)
*/
public static enum SparseType {
CSR,
CSC;
}
/**
* init DMatrix from file (svmlight format)
* @param dataPath path of data file
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public DMatrix(String dataPath) throws XGBoostError {
if(dataPath == null) {
throw new NullPointerException("dataPath: null");
}
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
}
/**
* create DMatrix from sparse matrix
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC)
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if(st == SparseType.CSR) {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
}
else if(st == SparseType.CSC) {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
}
else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
}
/**
* create DMatrix from dense matrix
* @param data data values
* @param nrow number of rows
* @param ncol number of columns
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
handle = out[0];
}
/**
* used for DMatrix slice
* @param handle
*/
private DMatrix(long handle) {
this.handle = handle;
}
/**
* set label of dmatrix
* @param labels labels
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void setLabel(float[] labels) throws XGBoostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
}
/**
* set weight of each instance
* @param weights weights
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void setWeight(float[] weights) throws XGBoostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
* @param baseMargin base margin
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
* @param baseMargin base margin
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
float[] flattenMargin = flatten(baseMargin);
setBaseMargin(flattenMargin);
}
/**
* Set group sizes of DMatrix (used for ranking)
* @param group group size as array
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void setGroup(int[] group) throws XGBoostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
}
private float[] getFloatInfo(String field) throws XGBoostError {
float[][] infos = new float[1][];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
return infos[0];
}
private int[] getIntInfo(String field) throws XGBoostError {
int[][] infos = new int[1][];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
return infos[0];
}
/**
* get label values
* @return label
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[] getLabel() throws XGBoostError {
return getFloatInfo("label");
}
/**
* get weight of the DMatrix
* @return weights
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[] getWeight() throws XGBoostError {
return getFloatInfo("weight");
}
/**
* get base margin of the DMatrix
* @return base margin
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public float[] getBaseMargin() throws XGBoostError {
return getFloatInfo("base_margin");
}
/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
* @param rowIndex row index
* @return sliced new DMatrix
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public DMatrix slice(int[] rowIndex) throws XGBoostError {
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
long sHandle = out[0];
DMatrix sMatrix = new DMatrix(sHandle);
return sMatrix;
}
/**
* get the row number of DMatrix
* @return number of rows
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public long rowNum() throws XGBoostError {
long[] rowNum = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle,rowNum));
return rowNum[0];
}
/**
* save DMatrix to filePath
* @param filePath file path
*/
public void saveBinary(String filePath) {
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
}
/**
* Get the handle
* @return native handler id
*/
public long getHandle() {
return handle;
}
/**
* flatten a mat to array
* @param mat
* @return
*/
private static float[] flatten(float[][] mat) {
int size = 0;
for (float[] array : mat) size += array.length;
float[] result = new float[size];
int pos = 0;
for (float[] ar : mat) {
System.arraycopy(ar, 0, result, pos, ar.length);
pos += ar.length;
}
return result;
}
@Override
protected void finalize() {
delete();
}
public synchronized void delete() {
if(handle != 0) {
XgboostJNI.XGDMatrixFree(handle);
handle = 0;
}
}
}

View File

@ -1,89 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.Map;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
/**
* cross validation package for xgb
* @author hzx
*/
public class CVPack {
DMatrix dtrain;
DMatrix dtest;
DMatrix[] dmats;
String[] names;
Booster booster;
/**
* create an cross validation package
* @param dtrain train data
* @param dtest test data
* @param params parameters
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) throws XGBoostError {
dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats);
names = new String[] {"train", "test"};
this.dtrain = dtrain;
this.dtest = dtest;
}
/**
* update one iteration
* @param iter iteration num
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void update(int iter) throws XGBoostError {
booster.update(dtrain, iter);
}
/**
* update one iteration
* @param iter iteration num
* @param obj customized objective
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public void update(int iter, IObjective obj) throws XGBoostError {
booster.update(dtrain, iter, obj);
}
/**
* evaluation
* @param iter iteration num
* @return evaluation
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String eval(int iter) throws XGBoostError {
return booster.evalSet(dmats, names, iter);
}
/**
* evaluation
* @param iter iteration num
* @param eval customized eval
* @return evaluation
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public String eval(int iter, IEvaluation eval) throws XGBoostError {
return booster.evalSet(dmats, names, iter, eval);
}
}

View File

@ -1,49 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
* Error handle for Xgboost.
*/
public class ErrorHandle {
private static final Log logger = LogFactory.getLog(ErrorHandle.class);
//load native library
static {
try {
Initializer.InitXgboost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* Check the return value of C API.
* @param ret return valud of xgboostJNI C API call
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public static void checkCall(int ret) throws XGBoostError {
if(ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
}

View File

@ -1,92 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.io.IOException;
import java.lang.reflect.Field;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* class to load native library
* @author hzx
*/
public class Initializer {
private static final Log logger = LogFactory.getLog(Initializer.class);
static boolean initialized = false;
public static final String nativePath = "./lib";
public static final String nativeResourcePath = "/lib/";
public static final String[] libNames = new String[] {"xgboost4j"};
public static synchronized void InitXgboost() throws IOException {
if(initialized == false) {
for(String libName: libNames) {
smartLoad(libName);
}
initialized = true;
}
}
/**
* load native library, this method will first try to load library from java.library.path, then try to load library in jar package.
* @param libName library path
* @throws IOException exception
*/
private static void smartLoad(String libName) throws IOException {
addNativeDir(nativePath);
try {
System.loadLibrary(libName);
}
catch (UnsatisfiedLinkError e) {
try {
NativeUtils.loadLibraryFromJar(nativeResourcePath + System.mapLibraryName(libName));
}
catch (IOException e1) {
throw e1;
}
}
}
/**
* Add libPath to java.library.path, then native library in libPath would be load properly
* @param libPath library path
* @throws IOException exception
*/
public static void addNativeDir(String libPath) throws IOException {
try {
Field field = ClassLoader.class.getDeclaredField("usr_paths");
field.setAccessible(true);
String[] paths = (String[]) field.get(null);
for (String path : paths) {
if (libPath.equals(path)) {
return;
}
}
String[] tmp = new String[paths.length+1];
System.arraycopy(paths,0,tmp,0,paths.length);
tmp[paths.length] = libPath;
field.set(null, tmp);
} catch (IllegalAccessException e) {
logger.error(e.getMessage());
throw new IOException("Failed to get permissions to set library path");
} catch (NoSuchFieldException e) {
logger.error(e.getMessage());
throw new IOException("Failed to get field handle to set library path");
}
}
}

View File

@ -1,113 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
/**
* Simple library class for working with JNI (Java Native Interface)
* <p>
* See <a href="http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar">
* http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar</a>
* <p>
* Author Adam Heirnich &lt;adam@adamh.cz&gt;, http://www.adamh.cz
*/
public class NativeUtils {
/**
* Private constructor - this class will never be instanced
*/
private NativeUtils() {
}
/**
* Loads library from current JAR archive
* <p>
* The file from JAR is copied into system temporary directory and then loaded.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}.
*
* @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
* @throws IOException If temporary file creation or read/write operation fails
* @throws IllegalArgumentException If source file (param path) does not exist
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters
*/
public static void loadLibraryFromJar(String path) throws IOException {
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
// Obtain filename from path
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "."+parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeUtils.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
// Finally, load the library
System.load(temp.getAbsolutePath());
}
}

View File

@ -1,238 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
/**
* trainer for xgboost
* @author hzx
*/
public class Trainer {
private static final Log logger = LogFactory.getLog(Trainer.class);
/**
* Train a booster with given parameters.
* @param params Booster params.
* @param dtrain Data to be trained.
* @param round Number of boosting iterations.
* @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set.
* @param obj customized objective (set to null if not used)
* @param eval customized evaluation (set to null if not used)
* @return trained booster
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
List<String> names = new ArrayList<>();
List<DMatrix> mats = new ArrayList<>();
for(Entry<String, DMatrix> evalEntry : watchs) {
names.add(evalEntry.getKey());
mats.add(evalEntry.getValue());
}
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
//collect all data matrixs
DMatrix[] allMats;
if(evalMats!=null && evalMats.length>0) {
allMats = new DMatrix[evalMats.length+1];
allMats[0] = dtrain;
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
}
else {
allMats = new DMatrix[1];
allMats[0] = dtrain;
}
//initialize booster
Booster booster = new Booster(params, allMats);
//begin to train
for(int iter=0; iter<round; iter++) {
if(obj != null) {
booster.update(dtrain, iter, obj);
} else {
booster.update(dtrain, iter);
}
//evaluation
if(evalMats!=null && evalMats.length>0) {
String evalInfo;
if(eval != null) {
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
}
else {
evalInfo = booster.evalSet(evalMats, evalNames, iter);
}
logger.info(evalInfo);
}
}
return booster;
}
/**
* Cross-validation with given paramaters.
* @param params Booster params.
* @param data Data to be trained.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
* @param obj customized objective (set to null if not used)
* @param eval customized evaluation (set to null if not used)
* @return evaluation history
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XGBoostError {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round];
String[] results = new String[cvPacks.length];
for(int i=0; i<round; i++) {
for(CVPack cvPack : cvPacks) {
if(obj != null) {
cvPack.update(i, obj);
}
else {
cvPack.update(i);
}
}
for(int j=0; j<cvPacks.length; j++) {
if(eval != null) {
results[j] = cvPacks[j].eval(i, eval);
}
else {
results[j] = cvPacks[j].eval(i);
}
}
evalHist[i] = aggCVResults(results);
logger.info(evalHist[i]);
}
return evalHist;
}
/**
* make an n-fold array of CVPack from random indices
* @param data original data
* @param nfold num of folds
* @param params booster parameters
* @param evalMetrics Evaluation metrics
* @return CV package array
* @throws org.dmlc.xgboost4j.util.XGBoostError native error
*/
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) throws XGBoostError {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size()/nfold;
int[] testSlice = new int[step];
int[] trainSlice = new int[samples.size()-step];
int testid, trainid;
CVPack[] cvPacks = new CVPack[nfold];
for(int i=0; i<nfold; i++) {
testid = 0;
trainid = 0;
for(int j=0; j<samples.size(); j++) {
if(j>(i*step) && j<(i*step+step) && testid<step) {
testSlice[testid] = samples.get(j);
testid++;
}
else{
if(trainid<samples.size()-step) {
trainSlice[trainid] = samples.get(j);
trainid++;
}
else {
testSlice[testid] = samples.get(j);
testid++;
}
}
}
DMatrix dtrain = data.slice(trainSlice);
DMatrix dtest = data.slice(testSlice);
CVPack cvPack = new CVPack(dtrain, dtest, params);
//set eval types
if(evalMetrics!=null) {
for(String type : evalMetrics) {
cvPack.booster.setParam("eval_metric", type);
}
}
cvPacks[i] = cvPack;
}
return cvPacks;
}
private static List<Integer> genRandPermutationNums(int start, int end) {
List<Integer> samples = new ArrayList<>();
for(int i=start; i<end; i++) {
samples.add(i);
}
Collections.shuffle(samples);
return samples;
}
/**
* Aggregate cross-validation results.
* @param results eval info from each data sample
* @return cross-validation eval info
*/
public static String aggCVResults(String[] results) {
Map<String, List<Float> > cvMap = new HashMap<>();
String aggResult = results[0].split("\t")[0];
for(String result : results) {
String[] items = result.split("\t");
for(int i=1; i<items.length; i++) {
String[] tup = items[i].split(":");
String key = tup[0];
Float value = Float.valueOf(tup[1]);
if(!cvMap.containsKey(key)) {
cvMap.put(key, new ArrayList<Float>());
}
cvMap.get(key).add(value);
}
}
for(String key : cvMap.keySet()) {
float value = 0f;
for(Float tvalue : cvMap.get(key)) {
value += tvalue;
}
value /= cvMap.get(key).size();
aggResult += String.format("\tcv-%s:%f", key, value);
}
return aggResult;
}
}

View File

@ -1,142 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import junit.framework.TestCase;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
import org.junit.Test;
/**
* test cases for Booster
* @author hzx
*/
public class BoosterTest {
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0) {
error++;
}
else if(labels[i]==1f && predicts[i][0]<=0) {
error++;
}
}
return error/labels.length;
}
}
@Test
public void testBoosterBasic() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
Iterable<Entry<String, Object>> param = paramMap.entrySet();
//set watchList
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//set round
int round = 2;
//train a boost model
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
//predict raw output
float[][] predicts = booster.predict(testMat, true);
//eval
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
//test dump model
}
/**
* test cross valiation
* @throws XGBoostError
*/
@Test
public void testCV() throws XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//set params
Map<String, Object> param= new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 3);
put("silent", 1);
put("nthread", 6);
put("objective", "binary:logistic");
put("gamma", 1.0);
put("eval_metric", "error");
}
};
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = Trainer.crossValiation(param.entrySet(), trainMat, round, nfold, metrics, null, null);
}
}

View File

@ -1,102 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import org.dmlc.xgboost4j.util.XGBoostError;
import org.junit.Test;
/**
* test cases for DMatrix
* @author hzx
*/
public class DMatrixTest {
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
//get label
float[] labels = dmat.getLabel();
//check length
TestCase.assertTrue(dmat.rowNum()==labels.length);
//set weights
float[] weights = Arrays.copyOf(labels, labels.length);
dmat.setWeight(weights);
float[] dweights = dmat.getWeight();
TestCase.assertTrue(Arrays.equals(weights, dweights));
}
@Test
public void testCreateFromCSR() throws XGBoostError {
//create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
long[] rowHeaders = new long[] {0, 3, 7, 11};
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
//check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum()==3);
//test set label
float[] label1 = new float[] {1, 0, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
float[] data0 = new float[nrow*ncol];
//put random nums
Random random = new Random();
for(int i=0; i<nrow*ncol; i++) {
data0[i] = random.nextFloat();
}
//create label
float[] label0 = new float[nrow];
for(int i=0; i<nrow; i++) {
label0[i] = random.nextFloat();
}
DMatrix dmat0 = new DMatrix(data0, nrow, ncol);
dmat0.setLabel(label0);
//check
TestCase.assertTrue(dmat0.rowNum()==10);
TestCase.assertTrue(dmat0.getLabel().length==10);
//set weights for each instance
float[] weights = new float[nrow];
for(int i=0; i<nrow; i++) {
weights[i] = random.nextFloat();
}
dmat0.setWeight(weights);
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
}

View File

@ -1,221 +0,0 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_dmlc_xgboost4j_wrapper_XgboostJNI */
#ifndef _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
#define _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBGetLastError
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
(JNIEnv *, jclass);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *, jclass, jstring, jint, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromMat
* Signature: ([FIIF[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSliceDMatrix
* Signature: (J[I[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *, jclass, jlong, jintArray, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
(JNIEnv *, jclass, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *, jclass, jlong, jstring, jint);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;[[F)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;[[I)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixNumRow
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterCreate
* Signature: ([J[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
(JNIEnv *, jclass, jlongArray, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
(JNIEnv *, jclass, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterUpdateOneIter
* Signature: (JIJ)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *, jclass, jlong, jint, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterPredict
* Signature: (JJIJ[[F)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv *, jclass, jlong, jobjectArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,33 @@
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<!DOCTYPE suppressions PUBLIC
"-//Puppy Crawl//DTD Suppressions 1.1//EN"
"http://www.puppycrawl.com/dtds/suppressions_1_1.dtd">
<!--
This file contains suppression rules for Checkstyle checks.
Ideally only files that cannot be modified (e.g. third-party code)
should be added here. All other violations should be fixed.
-->
<suppressions>
<suppress checks=".*"
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java"/>
</suppressions>

169
jvm-packages/checkstyle.xml Normal file
View File

@ -0,0 +1,169 @@
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<!DOCTYPE module PUBLIC
"-//Puppy Crawl//DTD Check Configuration 1.3//EN"
"http://www.puppycrawl.com/dtds/configuration_1_3.dtd">
<!--
Checkstyle configuration based on the Google coding conventions from:
- Google Java Style
https://google-styleguide.googlecode.com/svn-history/r130/trunk/javaguide.html
with Spark-specific changes from:
https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide
Checkstyle is very configurable. Be sure to read the documentation at
http://checkstyle.sf.net (or in your downloaded distribution).
Most Checks are configurable, be sure to consult the documentation.
To completely disable a check, just comment it out or delete it from the file.
Authors: Max Vetrenko, Ruslan Diachenko, Roman Ivanov.
-->
<module name = "Checker">
<property name="charset" value="UTF-8"/>
<property name="severity" value="error"/>
<property name="fileExtensions" value="java, properties, xml"/>
<module name="SuppressionFilter">
<property name="file" value="checkstyle-suppressions.xml"/>
</module>
<!-- Checks for whitespace -->
<!-- See http://checkstyle.sf.net/config_whitespace.html -->
<module name="FileTabCharacter">
<property name="eachLine" value="true"/>
</module>
<module name="RegexpSingleline">
<!-- \s matches whitespace character, $ matches end of line. -->
<property name="format" value="\s+$"/>
<property name="message" value="No trailing whitespace allowed."/>
</module>
<module name="TreeWalker">
<module name="OuterTypeFilename"/>
<module name="IllegalTokenText">
<property name="tokens" value="STRING_LITERAL, CHAR_LITERAL"/>
<property name="format" value="\\u00(08|09|0(a|A)|0(c|C)|0(d|D)|22|27|5(C|c))|\\(0(10|11|12|14|15|42|47)|134)"/>
<property name="message" value="Avoid using corresponding octal or Unicode escape."/>
</module>
<module name="AvoidEscapedUnicodeCharacters">
<property name="allowEscapesForControlCharacters" value="true"/>
<property name="allowByTailComment" value="true"/>
<property name="allowNonPrintableEscapes" value="true"/>
</module>
<!-- TODO: 11/09/15 disabled - the lengths are currently > 100 in many places -->
<module name="LineLength">
<property name="max" value="100"/>
<property name="ignorePattern" value="^package.*|^import.*|a href|href|http://|https://|ftp://"/>
</module>
<module name="NoLineWrap"/>
<module name="EmptyBlock">
<property name="option" value="TEXT"/>
<property name="tokens" value="LITERAL_TRY, LITERAL_FINALLY, LITERAL_IF, LITERAL_ELSE, LITERAL_SWITCH"/>
</module>
<module name="NeedBraces">
<property name="allowSingleLineStatement" value="true"/>
</module>
<module name="OneStatementPerLine"/>
<module name="ArrayTypeStyle"/>
<module name="FallThrough"/>
<module name="UpperEll"/>
<module name="ModifierOrder"/>
<module name="SeparatorWrap">
<property name="tokens" value="DOT"/>
<property name="option" value="nl"/>
</module>
<module name="SeparatorWrap">
<property name="tokens" value="COMMA"/>
<property name="option" value="EOL"/>
</module>
<module name="PackageName">
<property name="format" value="^[a-z]+(\.[a-z][a-z0-9]*)*$"/>
<message key="name.invalidPattern"
value="Package name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="ClassTypeParameterName">
<property name="format" value="([A-Z][a-zA-Z0-9]*$)"/>
<message key="name.invalidPattern"
value="Class type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="MethodTypeParameterName">
<property name="format" value="([A-Z][a-zA-Z0-9]*)"/>
<message key="name.invalidPattern"
value="Method type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="GenericWhitespace">
<message key="ws.followed"
value="GenericWhitespace ''{0}'' is followed by whitespace."/>
<message key="ws.preceded"
value="GenericWhitespace ''{0}'' is preceded with whitespace."/>
<message key="ws.illegalFollow"
value="GenericWhitespace ''{0}'' should followed by whitespace."/>
<message key="ws.notPreceded"
value="GenericWhitespace ''{0}'' is not preceded with whitespace."/>
</module>
<!-- TODO: 11/09/15 disabled - indentation is currently inconsistent -->
<!--
<module name="Indentation">
<property name="basicOffset" value="4"/>
<property name="braceAdjustment" value="0"/>
<property name="caseIndent" value="4"/>
<property name="throwsIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="arrayInitIndent" value="4"/>
</module>
-->
<!-- TODO: 11/09/15 disabled - order is currently wrong in many places -->
<!--
<module name="ImportOrder">
<property name="separated" value="true"/>
<property name="ordered" value="true"/>
<property name="groups" value="/^javax?\./,scala,*,org.apache.spark"/>
</module>
-->
<module name="MethodParamPad"/>
<module name="AnnotationLocation">
<property name="tokens" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF"/>
</module>
<module name="AnnotationLocation">
<property name="tokens" value="VARIABLE_DEF"/>
<property name="allowSamelineMultipleAnnotations" value="true"/>
</module>
<module name="MethodName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9_]*$"/>
<message key="name.invalidPattern"
value="Method name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="EmptyCatchBlock">
<property name="exceptionVariableName" value="expected"/>
</module>
<module name="CommentsIndentation"/>
</module>
</module>

View File

@ -16,8 +16,8 @@ if [ $(uname) == "Darwin" ]; then
fi fi
cd .. cd ..
make java no_omp=${dis_omp} make jvm no_omp=${dis_omp}
cd java cd jvm-packages
echo "move native lib" echo "move native lib"
libPath="xgboost4j/src/main/resources/lib" libPath="xgboost4j/src/main/resources/lib"
@ -26,7 +26,7 @@ if [ ! -d "$libPath" ]; then
fi fi
rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl} rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
mv libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl} mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
popd > /dev/null popd > /dev/null
echo "complete" echo "complete"

117
jvm-packages/pom.xml Normal file
View File

@ -0,0 +1,117 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
<packaging>pom</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
<maven.version>3.3.9</maven.version>
<scala.version>2.11.7</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<modules>
<module>xgboost4j</module>
<module>xgboost4j-demo</module>
</modules>
<build>
<plugins>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
<version>0.8.0</version>
<configuration>
<verbose>false</verbose>
<failOnViolation>true</failOnViolation>
<includeTestSourceDirectory>true</includeTestSourceDirectory>
<sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
<testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory>
<configLocation>scalastyle-config.xml</configLocation>
<outputEncoding>UTF-8</outputEncoding>
</configuration>
<executions>
<execution>
<id>checkstyle</id>
<phase>validate</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>2.17</version>
<configuration>
<configLocation>checkstyle.xml</configLocation>
<failOnViolation>true</failOnViolation>
</configuration>
<executions>
<execution>
<id>checkstyle</id>
<phase>validate</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.2.2</version>
<executions>
<execution>
<id>compile</id>
<goals>
<goal>compile</goal>
</goals>
<phase>compile</phase>
</execution>
<execution>
<id>test-compile</id>
<goals>
<goal>testCompile</goal>
</goals>
<phase>test-compile</phase>
</execution>
<execution>
<phase>process-resources</phase>
<goals>
<goal>compile</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.19.1</version>
<configuration>
<argLine>-Djava.library.path=lib/</argLine>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.2</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>2.2.6</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,291 @@
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<!--
If you wish to turn off checking for a section of code, you can put a comment in the source
before and after the section, with the following syntax:
// scalastyle:off
... // stuff that breaks the styles
// scalastyle:on
You can also disable only one rule, by specifying its rule id, as specified in:
http://www.scalastyle.org/rules-0.7.0.html
// scalastyle:off no.finalize
override def finalize(): Unit = ...
// scalastyle:on no.finalize
This file is divided into 3 sections:
(1) rules that we enforce.
(2) rules that we would like to enforce, but haven't cleaned up the codebase to turn on yet
(or we need to make the scalastyle rule more configurable).
(3) rules that we don't want to enforce.
-->
<scalastyle>
<name>Scalastyle standard configuration</name>
<!-- ================================================================================ -->
<!-- rules we enforce -->
<!-- ================================================================================ -->
<check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
<parameters>
<parameter name="header"><![CDATA[/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
<parameters>
<parameter name="maxLineLength"><![CDATA[100]]></parameter>
<parameter name="tabSize"><![CDATA[2]]></parameter>
<parameter name="ignoreImports">true</parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
<parameters><parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter></parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
<parameters><parameter name="maxParameters"><![CDATA[10]]></parameter></parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="false"></check>
<check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="true">
<parameters>
<parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
<parameter name="doubleLineAllowed"><![CDATA[true]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NonASCIICharacterChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SpaceAfterCommentStartChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceBeforeTokenChecker" enabled="true">
<parameters>
<parameter name="tokens">ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceAfterTokenChecker" enabled="true">
<parameters>
<parameter name="tokens">ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
</parameters>
</check>
<!-- ??? usually shouldn't be checked into the code base. -->
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
<!-- As of SPARK-7558, all tests in Spark should extend o.a.s.SparkFunSuite instead of FunSuite directly -->
<check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
<parameters><parameter name="regex">^FunSuite[A-Za-z]*$</parameter></parameters>
<customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
</check>
<!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
<check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
<parameters><parameter name="regex">^println$</parameter></parameters>
<customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
// scalastyle:off println
println(...)
// scalastyle:on println]]></customMessage>
</check>
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
<customMessage><![CDATA[
@VisibleForTesting causes classpath issues. Please note this in the java doc instead (SPARK-11615).
]]></customMessage>
</check>
<check customId="runtimeaddshutdownhook" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter name="regex">Runtime\.getRuntime\.addShutdownHook</parameter></parameters>
<customMessage><![CDATA[
Are you sure that you want to use Runtime.getRuntime.addShutdownHook? In most cases, you should use
ShutdownHookManager.addShutdownHook instead.
If you must use Runtime.getRuntime.addShutdownHook, wrap the code block with
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(...)
// scalastyle:on runtimeaddshutdownhook
]]></customMessage>
</check>
<check customId="mutablesynchronizedbuffer" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter name="regex">mutable\.SynchronizedBuffer</parameter></parameters>
<customMessage><![CDATA[
Are you sure that you want to use mutable.SynchronizedBuffer? In most cases, you should use
java.util.concurrent.ConcurrentLinkedQueue instead.
If you must use mutable.SynchronizedBuffer, wrap the code block with
// scalastyle:off mutablesynchronizedbuffer
mutable.SynchronizedBuffer[...]
// scalastyle:on mutablesynchronizedbuffer
]]></customMessage>
</check>
<check customId="classforname" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter name="regex">Class\.forName</parameter></parameters>
<customMessage><![CDATA[
Are you sure that you want to use Class.forName? In most cases, you should use Utils.classForName instead.
If you must use Class.forName, wrap the code block with
// scalastyle:off classforname
Class.forName(...)
// scalastyle:on classforname
]]></customMessage>
</check>
<!-- As of SPARK-9613 JavaConversions should be replaced with JavaConverters -->
<check customId="javaconversions" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
<parameters><parameter name="regex">JavaConversions</parameter></parameters>
<customMessage>Instead of importing implicits in scala.collection.JavaConversions._, import
scala.collection.JavaConverters._ and use .asScala / .asJava methods</customMessage>
</check>
<check level="error" class="org.scalastyle.scalariform.ImportOrderChecker" enabled="true">
<parameters>
<parameter name="groups">java,scala,3rdParty,spark</parameter>
<parameter name="group.java">javax?\..*</parameter>
<parameter name="group.scala">scala\..*</parameter>
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.DisallowSpaceBeforeTokenChecker" enabled="true">
<parameters>
<parameter name="tokens">COMMA</parameter>
</parameters>
</check>
<!-- ================================================================================ -->
<!-- rules we'd like to enforce, but haven't cleaned up the codebase yet -->
<!-- ================================================================================ -->
<!-- We cannot turn the following two on, because it'd fail a lot of string interpolation use cases. -->
<!-- Ideally the following two rules should be configurable to rule out string interpolation. -->
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="false"></check>
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="false"></check>
<!-- This breaks symbolic method names so we don't turn it on. -->
<!-- Maybe we should update it to allow basic symbolic names, and then we are good to go. -->
<check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="false">
<parameters>
<parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
</parameters>
</check>
<!-- Should turn this on, but we have a few places that need to be fixed first -->
<check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="false"></check>
<!-- ================================================================================ -->
<!-- rules we don't want -->
<!-- ================================================================================ -->
<check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="false">
<parameters><parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter></parameters>
</check>
<!-- We want the opposite of this: NewLineAtEofChecker -->
<check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"></check>
<!-- This one complains about all kinds of random things. Disable. -->
<check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="false"></check>
<!-- We use return quite a bit for control flows and guards -->
<check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="false"></check>
<!-- We use null a lot in low level code and to interface with 3rd party code -->
<check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="false"></check>
<!-- Doesn't seem super big deal here ... -->
<check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="false"></check>
<!-- Doesn't seem super big deal here ... -->
<check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="false">
<parameters><parameter name="maxFileLength">800></parameter></parameters>
</check>
<!-- Doesn't seem super big deal here ... -->
<check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="false">
<parameters><parameter name="maxTypes">30</parameter></parameters>
</check>
<!-- Doesn't seem super big deal here ... -->
<check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="false">
<parameters><parameter name="maximum">10</parameter></parameters>
</check>
<!-- Doesn't seem super big deal here ... -->
<check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="false">
<parameters><parameter name="maxLength">50</parameter></parameters>
</check>
<!-- Not exactly feasible to enforce this right now. -->
<!-- It is also infrequent that somebody introduces a new class with a lot of methods. -->
<check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="false">
<parameters><parameter name="maxMethods"><![CDATA[30]]></parameter></parameters>
</check>
<!-- Doesn't seem super big deal here, and we have a lot of magic numbers ... -->
<check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="false">
<parameters><parameter name="ignore">-1,0,1,2,3</parameter></parameters>
</check>
</scalastyle>

View File

@ -0,0 +1,26 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4j-demo</artifactId>
<version>0.1</version>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,120 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.demo.util.DataLoader;
/**
* a simple example of java wrapper for xgboost
*
* @author hzx
*/
public class BasicWalkThrough {
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
if (fPredicts.length != sPredicts.length) {
return false;
}
for (int i = 0; i < fPredicts.length; i++) {
if (!Arrays.equals(fPredicts[i], sPredicts[i])) {
return false;
}
}
return true;
}
public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict
float[][] predicts = booster.predict(testMat);
//save model to modelPath
File file = new File("./model");
if (!file.exists()) {
file.mkdirs();
}
String modelPath = "./model/xgb.model";
booster.saveModel(modelPath);
//dump model
booster.dumpModel("./model/dump.raw.txt", false);
//dump model with feature map
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
//save dmatrix into binary buffer
testMat.saveBinary("./model/dtest.buffer");
//reload model and data
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
float[][] predicts2 = booster2.predict(testMat2);
//check the two predicts
System.out.println(checkPredicts(predicts, predicts2));
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR);
trainMat2.setLabel(spData.labels);
//specify watchList
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
watches2.put("train", trainMat2);
watches2.put("test", testMat2);
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
float[][] predicts3 = booster3.predict(testMat2);
//check predicts
System.out.println(checkPredicts(predicts, predicts3));
}
}

View File

@ -0,0 +1,62 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* example for start from a initial base prediction
*
* @author hzx
*/
public class BoostFromPrediction {
public static void main(String[] args) throws XGBoostError {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train xgboost for 1 round
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
float[][] trainPred = booster.predict(trainMat, true);
float[][] testPred = booster.predict(testMat, true);
trainMat.setBaseMargin(trainPred);
testMat.setBaseMargin(testPred);
System.out.println("result of running from initial prediction");
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
}
}

View File

@ -0,0 +1,54 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.io.IOException;
import java.util.HashMap;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* an example of cross validation
*
* @author hzx
*/
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//set params
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 3);
params.put("silent", 1);
params.put("nthread", 6);
params.put("objective", "binary:logistic");
params.put("gamma", 1.0);
params.put("eval_metric", "error");
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
}
}

View File

@ -0,0 +1,167 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import ml.dmlc.xgboost4j.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* an example user define objective and eval
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* he buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized evaluation
* function
*
* @author hzx
*/
public class CustomObjective {
/**
* loglikelihoode loss obj function
*/
public static class LogRegObj implements IObjective {
private static final Log logger = LogFactory.getLog(LogRegObj.class);
/**
* simple sigmoid func
*
* @param input
* @return Note: this func is not concern about numerical stability, only used as example
*/
public float sigmoid(float input) {
float val = (float) (1 / (1 + Math.exp(-input)));
return val;
}
public float[][] transform(float[][] predicts) {
int nrow = predicts.length;
float[][] transPredicts = new float[nrow][1];
for (int i = 0; i < nrow; i++) {
transPredicts[i][0] = sigmoid(predicts[i][0]);
}
return transPredicts;
}
@Override
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
int nrow = predicts.length;
List<float[]> gradients = new ArrayList<float[]>();
float[] labels;
try {
labels = dtrain.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return null;
}
float[] grad = new float[nrow];
float[] hess = new float[nrow];
float[][] transPredicts = transform(predicts);
for (int i = 0; i < nrow; i++) {
float predict = transPredicts[i][0];
grad[i] = predict - labels[i];
hess[i] = predict * (1 - predict);
}
gradients.add(grad);
gradients.add(hess);
return gradients;
}
}
/**
* user defined eval function.
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* the buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized
* evaluation function
*/
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for (int i = 0; i < nrow; i++) {
if (labels[i] == 0f && predicts[i][0] > 0) {
error++;
} else if (labels[i] == 1f && predicts[i][0] <= 0) {
error++;
}
}
return error / labels.length;
}
}
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//load valid mat (svmlight format)
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
//set round
int round = 2;
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//user define obj and eval
IObjective obj = new LogRegObj();
IEvaluation eval = new EvalError();
//train a booster
System.out.println("begin to train the booster model");
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
}
}

View File

@ -0,0 +1,61 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* simple example for using external memory version
*
* @author hzx
*/
public class ExternalMemory {
public static void main(String[] args) throws XGBoostError {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//performance notice: set nthread to be the number of your real cpu
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
// set nthread=4
//param.put("nthread", num_real_cpu);
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
}
}

View File

@ -0,0 +1,70 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.demo.util.CustomEval;
import java.util.HashMap;
/**
* this is an example of fit generalized linear model in xgboost
* basically, we are using linear model, instead of tree for our boosters
*
* @author hzx
*/
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model
// alpha is the L1 regularizer
//lambda is the L2 regularizer
//you can also set lambda_bias which is L2 regularizer on the bias term
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("alpha", 0.0001);
params.put("silent", 1);
params.put("objective", "binary:logistic");
params.put("booster", "gblinear");
//normally, you do not need to set eta (step_size)
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
//there could be affection on convergence with parallelization on certain cases
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
//param.put("eta", "0.5");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train a booster
int round = 4;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
float[][] predicts = booster.predict(testMat);
CustomEval eval = new CustomEval();
System.out.println("error=" + eval.eval(predicts, testMat));
}
}

View File

@ -0,0 +1,66 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.demo.util.CustomEval;
/**
* predict first ntree
*
* @author hzx
*/
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train a booster
int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict use 1 tree
float[][] predicts1 = booster.predict(testMat, false, 1);
//by default all trees are used to do predict
float[][] predicts2 = booster.predict(testMat);
//use a simple evaluation class to check error result
CustomEval eval = new CustomEval();
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
}
}

View File

@ -0,0 +1,66 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* predict leaf indices
*
* @author hzx
*/
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train a booster
int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict using first 2 tree
float[][] leafindex = booster.predict(testMat, 2, true);
for (float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
//predict all trees
leafindex = booster.predict(testMat, 0, true);
for (float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
}
}

View File

@ -0,0 +1,60 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.IEvaluation;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* a util evaluation class for examples
*
* @author hzx
*/
public class CustomEval implements IEvaluation {
private static final Log logger = LogFactory.getLog(CustomEval.class);
String evalMetric = "custom_error";
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for (int i = 0; i < nrow; i++) {
if (labels[i] == 0f && predicts[i][0] > 0.5) {
error++;
} else if (labels[i] == 1f && predicts[i][0] <= 0.5) {
error++;
}
}
return error / labels.length;
}
}

View File

@ -0,0 +1,123 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo.util;
import org.apache.commons.lang3.ArrayUtils;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
/**
* util class for loading data
*
* @author hzx
*/
public class DataLoader {
public static class DenseData {
public float[] labels;
public float[] data;
public int nrow;
public int ncol;
}
public static class CSRSparseData {
public float[] labels;
public float[] data;
public long[] rowHeaders;
public int[] colIndex;
}
public static DenseData loadCSVFile(String filePath) throws IOException {
DenseData denseData = new DenseData();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
denseData.nrow = 0;
denseData.ncol = -1;
String line;
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
while ((line = reader.readLine()) != null) {
String[] items = line.trim().split(",");
if (items.length == 0) {
continue;
}
denseData.nrow++;
if (denseData.ncol == -1) {
denseData.ncol = items.length - 1;
}
tlabels.add(Float.valueOf(items[items.length - 1]));
for (int i = 0; i < items.length - 1; i++) {
tdata.add(Float.valueOf(items[i]));
}
}
reader.close();
in.close();
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
return denseData;
}
public static CSRSparseData loadSVMFile(String filePath) throws IOException {
CSRSparseData spData = new CSRSparseData();
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
List<Long> theaders = new ArrayList<>();
List<Integer> tindex = new ArrayList<>();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
String line;
long rowheader = 0;
theaders.add(rowheader);
while ((line = reader.readLine()) != null) {
String[] items = line.trim().split(" ");
if (items.length == 0) {
continue;
}
rowheader += items.length - 1;
theaders.add(rowheader);
tlabels.add(Float.valueOf(items[0]));
for (int i = 1; i < items.length; i++) {
String[] tup = items[i].split(":");
assert tup.length == 2;
tdata.add(Float.valueOf(tup[1]));
tindex.add(Integer.valueOf(tup[0]));
}
}
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
return spData;
}
}

View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
<packaging>jar</packaging>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.10.3</version>
<configuration>
<show>protected</show>
<nohelp>true</nohelp>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,153 @@
package ml.dmlc.xgboost4j;
import java.io.IOException;
import java.util.Map;
public interface Booster {
/**
* set parameter
*
* @param key param name
* @param value param value
*/
void setParam(String key, String value) throws XGBoostError;
/**
* set parameters
*
* @param params parameters key-value map
*/
void setParams(Map<String, Object> params) throws XGBoostError;
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
*/
void update(DMatrix dtrain, int iter) throws XGBoostError;
/**
* update with customize obj func
*
* @param dtrain training data
* @param obj customized objective class
*/
void update(DMatrix dtrain, IObjective obj) throws XGBoostError;
/**
* update with give grad and hess
*
* @param dtrain training data
* @param grad first order of gradient
* @param hess seconde order of gradient
*/
void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError;
/**
* evaluate with given dmatrixs.
*
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
*/
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError;
/**
* evaluate with given customized Evaluation class
*
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
*/
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError;
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
*/
float[][] predict(DMatrix data) throws XGBoostError;
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
*/
float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError;
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
*/
float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError;
/**
* Predict with data
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
/**
* save model to modelPath
*
* @param modelPath model path
*/
void saveModel(String modelPath) throws XGBoostError;
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool Controls whether the split statistics are output.
*/
void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError;
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
*/
void dumpModel(String modelPath, String featureMap, boolean withStats)
throws IOException, XGBoostError;
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
*/
Map<String, Integer> getFeatureScore() throws XGBoostError ;
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
*/
Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError;
void dispose();
}

View File

@ -0,0 +1,256 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.IOException;
/**
* DMatrix for xgboost, similar to the python wrapper xgboost.py
*
* @author hzx
*/
public class DMatrix {
private static final Log logger = LogFactory.getLog(DMatrix.class);
private long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* sparse matrix type (CSR or CSC)
*/
public static enum SparseType {
CSR,
CSC;
}
public DMatrix(String dataPath) throws XGBoostError {
if (dataPath == null) {
throw new NullPointerException("dataPath: null");
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
}
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
} else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
}
/**
* create DMatrix from dense matrix
*
* @param data data values
* @param nrow number of rows
* @param ncol number of columns
* @throws XGBoostError native error
*/
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
handle = out[0];
}
/**
* used for DMatrix slice
*/
protected DMatrix(long handle) {
this.handle = handle;
}
/**
* set label of dmatrix
*
* @param labels labels
* @throws XGBoostError native error
*/
public void setLabel(float[] labels) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
}
/**
* set weight of each instance
*
* @param weights weights
* @throws XGBoostError native error
*/
public void setWeight(float[] weights) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
*
* @param baseMargin base margin
* @throws XGBoostError native error
*/
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
*
* @param baseMargin base margin
* @throws XGBoostError native error
*/
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
float[] flattenMargin = flatten(baseMargin);
setBaseMargin(flattenMargin);
}
/**
* Set group sizes of DMatrix (used for ranking)
*
* @param group group size as array
* @throws XGBoostError native error
*/
public void setGroup(int[] group) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
}
private float[] getFloatInfo(String field) throws XGBoostError {
float[][] infos = new float[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
return infos[0];
}
private int[] getIntInfo(String field) throws XGBoostError {
int[][] infos = new int[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
return infos[0];
}
/**
* get label values
*
* @return label
* @throws XGBoostError native error
*/
public float[] getLabel() throws XGBoostError {
return getFloatInfo("label");
}
/**
* get weight of the DMatrix
*
* @return weights
* @throws XGBoostError native error
*/
public float[] getWeight() throws XGBoostError {
return getFloatInfo("weight");
}
/**
* get base margin of the DMatrix
*
* @return base margin
* @throws XGBoostError native error
*/
public float[] getBaseMargin() throws XGBoostError {
return getFloatInfo("base_margin");
}
/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*
* @param rowIndex row index
* @return sliced new DMatrix
* @throws XGBoostError native error
*/
public DMatrix slice(int[] rowIndex) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
long sHandle = out[0];
DMatrix sMatrix = new DMatrix(sHandle);
return sMatrix;
}
/**
* get the row number of DMatrix
*
* @return number of rows
* @throws XGBoostError native error
*/
public long rowNum() throws XGBoostError {
long[] rowNum = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle, rowNum));
return rowNum[0];
}
/**
* save DMatrix to filePath
*/
public void saveBinary(String filePath) {
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
}
/**
* Get the handle
*/
public long getHandle() {
return handle;
}
/**
* flatten a mat to array
*/
private static float[] flatten(float[][] mat) {
int size = 0;
for (float[] array : mat) size += array.length;
float[] result = new float[size];
int pos = 0;
for (float[] ar : mat) {
System.arraycopy(ar, 0, result, pos, ar.length);
pos += ar.length;
}
return result;
}
@Override
protected void finalize() {
dispose();
}
public synchronized void dispose() {
if (handle != 0) {
XgboostJNI.XGDMatrixFree(handle);
handle = 0;
}
}
}

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package org.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
/** /**
* interface for customized evaluation * interface for customized evaluation
@ -21,21 +21,19 @@ package org.dmlc.xgboost4j;
* @author hzx * @author hzx
*/ */
public interface IEvaluation { public interface IEvaluation {
/** /**
* get evaluate metric * get evaluate metric
* *
* @return evalMetric * @return evalMetric
*/ */
public abstract String getMetric(); String getMetric();
/** /**
* evaluate with predicts and data * evaluate with predicts and data
* *
* @param predicts * @param predicts predictions as array
* predictions as array * @param dmat data matrix to evaluate
* @param dmat * @return result of the metric
* data matrix to evaluate */
* @return result of the metric float eval(float[][] predicts, DMatrix dmat);
*/
public abstract float eval(float[][] predicts, DMatrix dmat);
} }

View File

@ -13,20 +13,22 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package org.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
import java.util.List; import java.util.List;
/** /**
* interface for customize Object function * interface for customize Object function
*
* @author hzx * @author hzx
*/ */
public interface IObjective { public interface IObjective {
/** /**
* user define objective function, return gradient and second order gradient * user define objective function, return gradient and second order gradient
* @param predicts untransformed margin predicts *
* @param dtrain training data * @param predicts untransformed margin predicts
* @return List with two float array, correspond to first order grad and second order grad * @param dtrain training data
*/ * @return List with two float array, correspond to first order grad and second order grad
public abstract List<float[]> getGradient(float[][] predicts, DMatrix dtrain); */
List<float[]> getGradient(float[][] predicts, DMatrix dtrain);
} }

View File

@ -0,0 +1,51 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.IOException;
/**
* Error handle for Xgboost.
*/
class JNIErrorHandle {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* Check the return value of C API.
*
* @param ret return valud of xgboostJNI C API call
* @throws XGBoostError native error
*/
static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
}

View File

@ -0,0 +1,470 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Booster for xgboost, similar to the python wrapper xgboost.py
* but custom obj function and eval function not supported at present.
*
* @author hzx
*/
class JavaBoosterImpl implements Booster {
private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class);
long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* init Booster from dMatrixs
*
* @param params parameters
* @param dMatrixs DMatrix array
* @throws XGBoostError native error
*/
JavaBoosterImpl(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
init(dMatrixs);
setParam("seed", "0");
setParams(params);
}
/**
* load model from modelPath
*
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws XGBoostError native error
*/
JavaBoosterImpl(Map<String, Object> params, String modelPath) throws XGBoostError {
init(null);
if (modelPath == null) {
throw new NullPointerException("modelPath : null");
}
loadModel(modelPath);
setParam("seed", "0");
setParams(params);
}
private void init(DMatrix[] dMatrixs) throws XGBoostError {
long[] handles = null;
if (dMatrixs != null) {
handles = dmatrixsToHandles(dMatrixs);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* set parameter
*
* @param key param name
* @param value param value
* @throws XGBoostError native error
*/
public final void setParam(String key, String value) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
}
/**
* set parameters
*
* @param params parameters key-value map
* @throws XGBoostError native error
*/
public void setParams(Map<String, Object> params) throws XGBoostError {
if (params != null) {
for (Map.Entry<String, Object> entry : params.entrySet()) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
* update with customize obj func
*
* @param dtrain training data
* @param obj customized objective class
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* update with give grad and hess
*
* @param dtrain training data
* @param grad first order of gradient
* @param hess seconde order of gradient
* @throws XGBoostError native error
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
hess));
}
/**
* evaluate with given dmatrixs.
*
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
/**
* evaluate with given customized Evaluation class
*
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
throws XGBoostError {
String evalInfo = "";
for (int i = 0; i < evalNames.length; i++) {
String evalName = evalNames[i];
DMatrix evalMat = evalMatrixs[i];
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
}
return evalInfo;
}
/**
* base function for Predict
*
* @param data data
* @param outPutMargin output margin
* @param treeLimit limit number of trees
* @param predLeaf prediction minimum to keep leafs
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if (outPutMargin) {
optionMask = 1;
}
if (predLeaf) {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
float[][] predicts = new float[row][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data) throws XGBoostError {
return pred(data, false, 0, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
return pred(data, outPutMargin, 0, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
return pred(data, outPutMargin, treeLimit, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index
* of each sample in each tree.
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
* in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
return pred(data, false, treeLimit, predLeaf);
}
/**
* save model to modelPath
*
* @param modelPath model path
*/
public void saveModel(String modelPath) throws XGBoostError{
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(handle, modelPath));
}
private void loadModel(String modelPath) {
XgboostJNI.XGBoosterLoadModel(handle, modelPath);
}
/**
* get the dump of the model as a string array
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
* get the dump of the model as a string array
*
* @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
modelInfos));
return modelInfos[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported feature
* @throws IOException error with model writing
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException exception
* @throws UnsupportedEncodingException exception
* @throws IOException exception
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(featureMap, withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore() throws XGBoostError {
String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* transfer DMatrix array to handle array (used for native functions)
*
* @param dmatrixs
* @return handle array for input dmatrixs
*/
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for (int i = 0; i < dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
@Override
protected void finalize() throws Throwable {
super.finalize();
dispose();
}
public synchronized void dispose() {
if (handle != 0L) {
XgboostJNI.XGBoosterFree(handle);
handle = 0;
}
}
}

View File

@ -0,0 +1,170 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.*;
import java.lang.reflect.Field;
/**
* class to load native library
*
* @author hzx
*/
class NativeLibLoader {
private static final Log logger = LogFactory.getLog(NativeLibLoader.class);
private static boolean initialized = false;
private static final String nativePath = "../lib/";
private static final String nativeResourcePath = "/lib/";
private static final String[] libNames = new String[]{"xgboost4j"};
public static synchronized void initXgBoost() throws IOException {
if (!initialized) {
for (String libName : libNames) {
smartLoad(libName);
}
initialized = true;
}
}
/**
* Loads library from current JAR archive
* <p/>
* The file from JAR is copied into system temporary directory and then loaded.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
*
* @param path The filename inside JAR as absolute path (beginning with '/'),
* e.g. /package/File.ext
* @throws IOException If temporary file creation or read/write operation fails
* @throws IllegalArgumentException If source file (param path) does not exist
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
* three characters
*/
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
// Obtain filename from path
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
// Finally, load the library
System.load(temp.getAbsolutePath());
}
/**
* load native library, this method will first try to load library from java.library.path, then
* try to load library in jar package.
*
* @param libName library path
* @throws IOException exception
*/
private static void smartLoad(String libName) throws IOException {
addNativeDir(nativePath);
try {
System.loadLibrary(libName);
} catch (UnsatisfiedLinkError e) {
try {
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
loadLibraryFromJar(libraryFromJar);
} catch (IOException e1) {
throw e1;
}
}
}
/**
* Add libPath to java.library.path, then native library in libPath would be load properly
*
* @param libPath library path
* @throws IOException exception
*/
private static void addNativeDir(String libPath) throws IOException {
try {
Field field = ClassLoader.class.getDeclaredField("usr_paths");
field.setAccessible(true);
String[] paths = (String[]) field.get(null);
for (String path : paths) {
if (libPath.equals(path)) {
return;
}
}
String[] tmp = new String[paths.length + 1];
System.arraycopy(paths, 0, tmp, 0, paths.length);
tmp[paths.length] = libPath;
field.set(null, tmp);
} catch (IllegalAccessException e) {
logger.error(e.getMessage());
throw new IOException("Failed to get permissions to set library path");
} catch (NoSuchFieldException e) {
logger.error(e.getMessage());
throw new IOException("Failed to get field handle to set library path");
}
}
}

View File

@ -0,0 +1,336 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.*;
/**
* trainer for xgboost
*
* @author hzx
*/
public class XGBoost {
private static final Log logger = LogFactory.getLog(XGBoost.class);
/**
* Train a booster with given parameters.
*
* @param params Booster params.
* @param dtrain Data to be trained.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param obj customized objective (set to null if not used)
* @param eval customized evaluation (set to null if not used)
* @return trained booster
* @throws XGBoostError native error
*/
public static Booster train(Map<String, Object> params, DMatrix dtrain, int round,
Map<String, DMatrix> watches, IObjective obj,
IEvaluation eval) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey());
mats.add(evalEntry.getValue());
}
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
//collect all data matrixs
DMatrix[] allMats;
if (evalMats != null && evalMats.length > 0) {
allMats = new DMatrix[evalMats.length + 1];
allMats[0] = dtrain;
System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
} else {
allMats = new DMatrix[1];
allMats[0] = dtrain;
}
//initialize booster
Booster booster = new JavaBoosterImpl(params, allMats);
//begin to train
for (int iter = 0; iter < round; iter++) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
//evaluation
if (evalMats != null && evalMats.length > 0) {
String evalInfo;
if (eval != null) {
evalInfo = booster.evalSet(evalMats, evalNames, eval);
} else {
evalInfo = booster.evalSet(evalMats, evalNames, iter);
}
logger.info(evalInfo);
}
}
return booster;
}
/**
* init Booster from dMatrixs
*
* @param params parameters
* @param dMatrixs DMatrix array
* @throws XGBoostError native error
*/
public static Booster initBoostingModel(
Map<String, Object> params,
DMatrix[] dMatrixs) throws XGBoostError {
return new JavaBoosterImpl(params, dMatrixs);
}
/**
* load model from modelPath
*
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws XGBoostError native error
*/
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
throws XGBoostError {
return new JavaBoosterImpl(params, modelPath);
}
/**
* Cross-validation with given paramaters.
*
* @param params Booster params.
* @param data Data to be trained.
* @param round Number of boosting iterations.
* @param nfold Number of folds in CV.
* @param metrics Evaluation metrics to be watched in CV.
* @param obj customized objective (set to null if not used)
* @param eval customized evaluation (set to null if not used)
* @return evaluation history
* @throws XGBoostError native error
*/
public static String[] crossValiation(
Map<String, Object> params,
DMatrix data,
int round,
int nfold,
String[] metrics,
IObjective obj,
IEvaluation eval) throws XGBoostError {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round];
String[] results = new String[cvPacks.length];
for (int i = 0; i < round; i++) {
for (CVPack cvPack : cvPacks) {
if (obj != null) {
cvPack.update(obj);
} else {
cvPack.update(i);
}
}
for (int j = 0; j < cvPacks.length; j++) {
if (eval != null) {
results[j] = cvPacks[j].eval(eval);
} else {
results[j] = cvPacks[j].eval(i);
}
}
evalHist[i] = aggCVResults(results);
logger.info(evalHist[i]);
}
return evalHist;
}
/**
* make an n-fold array of CVPack from random indices
*
* @param data original data
* @param nfold num of folds
* @param params booster parameters
* @param evalMetrics Evaluation metrics
* @return CV package array
* @throws XGBoostError native error
*/
private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params,
String[] evalMetrics) throws XGBoostError {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size() / nfold;
int[] testSlice = new int[step];
int[] trainSlice = new int[samples.size() - step];
int testid, trainid;
CVPack[] cvPacks = new CVPack[nfold];
for (int i = 0; i < nfold; i++) {
testid = 0;
trainid = 0;
for (int j = 0; j < samples.size(); j++) {
if (j > (i * step) && j < (i * step + step) && testid < step) {
testSlice[testid] = samples.get(j);
testid++;
} else {
if (trainid < samples.size() - step) {
trainSlice[trainid] = samples.get(j);
trainid++;
} else {
testSlice[testid] = samples.get(j);
testid++;
}
}
}
DMatrix dtrain = data.slice(trainSlice);
DMatrix dtest = data.slice(testSlice);
CVPack cvPack = new CVPack(dtrain, dtest, params);
//set eval types
if (evalMetrics != null) {
for (String type : evalMetrics) {
cvPack.booster.setParam("eval_metric", type);
}
}
cvPacks[i] = cvPack;
}
return cvPacks;
}
private static List<Integer> genRandPermutationNums(int start, int end) {
List<Integer> samples = new ArrayList<Integer>();
for (int i = start; i < end; i++) {
samples.add(i);
}
Collections.shuffle(samples);
return samples;
}
/**
* Aggregate cross-validation results.
*
* @param results eval info from each data sample
* @return cross-validation eval info
*/
private static String aggCVResults(String[] results) {
Map<String, List<Float>> cvMap = new HashMap<String, List<Float>>();
String aggResult = results[0].split("\t")[0];
for (String result : results) {
String[] items = result.split("\t");
for (int i = 1; i < items.length; i++) {
String[] tup = items[i].split(":");
String key = tup[0];
Float value = Float.valueOf(tup[1]);
if (!cvMap.containsKey(key)) {
cvMap.put(key, new ArrayList<Float>());
}
cvMap.get(key).add(value);
}
}
for (String key : cvMap.keySet()) {
float value = 0f;
for (Float tvalue : cvMap.get(key)) {
value += tvalue;
}
value /= cvMap.get(key).size();
aggResult += String.format("\tcv-%s:%f", key, value);
}
return aggResult;
}
/**
* cross validation package for xgb
*
* @author hzx
*/
private static class CVPack {
DMatrix dtrain;
DMatrix dtest;
DMatrix[] dmats;
String[] names;
Booster booster;
/**
* create an cross validation package
*
* @param dtrain train data
* @param dtest test data
* @param params parameters
* @throws XGBoostError native error
*/
public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params)
throws XGBoostError {
dmats = new DMatrix[]{dtrain, dtest};
booster = XGBoost.initBoostingModel(params, dmats);
names = new String[]{"train", "test"};
this.dtrain = dtrain;
this.dtest = dtest;
}
/**
* update one iteration
*
* @param iter iteration num
* @throws XGBoostError native error
*/
public void update(int iter) throws XGBoostError {
booster.update(dtrain, iter);
}
/**
* update one iteration
*
* @param obj customized objective
* @throws XGBoostError native error
*/
public void update(IObjective obj) throws XGBoostError {
booster.update(dtrain, obj);
}
/**
* evaluation
*
* @param iter iteration num
* @return evaluation
* @throws XGBoostError native error
*/
public String eval(int iter) throws XGBoostError {
return booster.evalSet(dmats, names, iter);
}
/**
* evaluation
*
* @param eval customized eval
* @return evaluation
* @throws XGBoostError native error
*/
public String eval(IEvaluation eval) throws XGBoostError {
return booster.evalSet(dmats, names, eval);
}
}
}

View File

@ -13,14 +13,15 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package org.dmlc.xgboost4j.util; package ml.dmlc.xgboost4j;
/** /**
* custom error class for xgboost * custom error class for xgboost
*
* @author hzx * @author hzx
*/ */
public class XGBoostError extends Exception{ public class XGBoostError extends Exception {
public XGBoostError(String message) { public XGBoostError(String message) {
super(message); super(message);
} }
} }

View File

@ -13,38 +13,71 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package org.dmlc.xgboost4j.wrapper; package ml.dmlc.xgboost4j;
/** /**
* xgboost jni wrapper functions for xgboost_wrapper.h * xgboost jni wrapper functions for xgboost_wrapper.h
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
*
* @author hzx * @author hzx
*/ */
public class XgboostJNI { class XgboostJNI {
public final static native String XGBGetLastError(); public final static native String XGBGetLastError();
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out);
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out); public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out); long[] out);
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data,
long[] out);
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
float missing, long[] out);
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out); public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
public final static native int XGDMatrixFree(long handle); public final static native int XGDMatrixFree(long handle);
public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent); public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent);
public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array); public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array);
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array); public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
public final static native int XGDMatrixSetGroup(long handle, int[] group); public final static native int XGDMatrixSetGroup(long handle, int[] group);
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info); public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info); public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
public final static native int XGDMatrixNumRow(long handle, long[] row); public final static native int XGDMatrixNumRow(long handle, long[] row);
public final static native int XGBoosterCreate(long[] handles, long[] out); public final static native int XGBoosterCreate(long[] handles, long[] out);
public final static native int XGBoosterFree(long handle); public final static native int XGBoosterFree(long handle);
public final static native int XGBoosterSetParam(long handle, String name, String value); public final static native int XGBoosterSetParam(long handle, String name, String value);
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain); public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info); public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, int ntree_limit, float[][] predicts); float[] hess);
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,
String[] evnames, String[] eval_info);
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
int ntree_limit, float[][] predicts);
public final static native int XGBoosterLoadModel(long handle, String fname); public final static native int XGBoosterLoadModel(long handle, String fname);
public final static native int XGBoosterSaveModel(long handle, String fname); public final static native int XGBoosterSaveModel(long handle, String fname);
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len); public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string); public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings);
} }

View File

@ -0,0 +1,189 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import java.io.IOException
import scala.collection.mutable
import ml.dmlc.xgboost4j.XGBoostError
trait Booster {
/**
* set parameter
*
* @param key param name
* @param value param value
*/
@throws(classOf[XGBoostError])
def setParam(key: String, value: String)
/**
* set parameters
*
* @param params parameters key-value map
*/
@throws(classOf[XGBoostError])
def setParams(params: Map[String, AnyRef])
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, iter: Int)
/**
* update with customize obj func
*
* @param dtrain training data
* @param obj customized objective class
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, obj: ObjectiveTrait)
/**
* update with give grad and hess
*
* @param dtrain training data
* @param grad first order of gradient
* @param hess seconde order of gradient
*/
@throws(classOf[XGBoostError])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
/**
* evaluate with given dmatrixs.
*
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
/**
* evaluate with given customized Evaluation class
*
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
/**
* save model to modelPath
*
* @param modelPath model path
*/
@throws(classOf[XGBoostError])
def saveModel(modelPath: String)
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool Controls whether the split statistics are output.
*/
@throws(classOf[IOException])
@throws(classOf[XGBoostError])
def dumpModel(modelPath: String, withStats: Boolean)
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
*/
@throws(classOf[IOException])
@throws(classOf[XGBoostError])
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore: mutable.Map[String, Integer]
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
def dispose
}

View File

@ -0,0 +1,177 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
/**
* init DMatrix from file (svmlight format)
*
* @param dataPath path of data file
* @throws XGBoostError native error
*/
def this(dataPath: String) {
this(new JDMatrix(dataPath))
}
/**
* create DMatrix from sparse matrix
*
* @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC)
*/
@throws(classOf[XGBoostError])
def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) {
this(new JDMatrix(headers, indices, data, st))
}
/**
* create DMatrix from dense matrix
*
* @param data data values
* @param nrow number of rows
* @param ncol number of columns
*/
@throws(classOf[XGBoostError])
def this(data: Array[Float], nrow: Int, ncol: Int) {
this(new JDMatrix(data, nrow, ncol))
}
/**
* set label of dmatrix
*
* @param labels labels
*/
@throws(classOf[XGBoostError])
def setLabel(labels: Array[Float]): Unit = {
jDMatrix.setLabel(labels)
}
/**
* set weight of each instance
*
* @param weights weights
*/
@throws(classOf[XGBoostError])
def setWeight(weights: Array[Float]): Unit = {
jDMatrix.setWeight(weights)
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
*
* @param baseMargin base margin
*/
@throws(classOf[XGBoostError])
def setBaseMargin(baseMargin: Array[Float]): Unit = {
jDMatrix.setBaseMargin(baseMargin)
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
*
* @param baseMargin base margin
*/
@throws(classOf[XGBoostError])
def setBaseMargin(baseMargin: Array[Array[Float]]): Unit = {
jDMatrix.setBaseMargin(baseMargin)
}
/**
* Set group sizes of DMatrix (used for ranking)
*
* @param group group size as array
*/
@throws(classOf[XGBoostError])
def setGroup(group: Array[Int]): Unit = {
jDMatrix.setGroup(group)
}
/**
* get label values
*
* @return label
*/
@throws(classOf[XGBoostError])
def getLabel: Array[Float] = {
jDMatrix.getLabel
}
/**
* get weight of the DMatrix
*
* @return weights
*/
@throws(classOf[XGBoostError])
def getWeight: Array[Float] = {
jDMatrix.getWeight
}
/**
* get base margin of the DMatrix
*
* @return base margin
*/
@throws(classOf[XGBoostError])
def getBaseMargin: Array[Float] = {
jDMatrix.getBaseMargin
}
/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*
* @param rowIndex row index
* @return sliced new DMatrix
*/
@throws(classOf[XGBoostError])
def slice(rowIndex: Array[Int]): DMatrix = {
new DMatrix(jDMatrix.slice(rowIndex))
}
/**
* get the row number of DMatrix
*
* @return number of rows
*/
@throws(classOf[XGBoostError])
def rowNum: Long = {
jDMatrix.rowNum
}
/**
* save DMatrix to filePath
*
* @param filePath file path
*/
def saveBinary(filePath: String): Unit = {
jDMatrix.saveBinary(filePath)
}
def getHandle: Long = {
jDMatrix.getHandle
}
def delete(): Unit = {
jDMatrix.dispose()
}
}

View File

@ -0,0 +1,38 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.IEvaluation
trait EvalTrait extends IEvaluation {
/**
* get evaluate metric
*
* @return evalMetric
*/
def getMetric: String
/**
* evaluate with predicts and data
*
* @param predicts predictions as array
* @param dmat data matrix to evaluate
* @return result of the metric
*/
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
}

View File

@ -0,0 +1,30 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.IObjective
trait ObjectiveTrait extends IObjective {
/**
* user define objective function, return gradient and second order gradient
*
* @param predicts untransformed margin predicts
* @param dtrain training data
* @return List with two float array, correspond to first order grad and second order grad
*/
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
}

View File

@ -0,0 +1,100 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import scala.collection.JavaConverters._
import scala.collection.mutable
import ml.dmlc.xgboost4j.{Booster => JBooster, IEvaluation, IObjective}
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) extends Booster {
override def setParam(key: String, value: String): Unit = {
booster.setParam(key, value)
}
override def update(dtrain: DMatrix, iter: Int): Unit = {
booster.update(dtrain.jDMatrix, iter)
}
override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, withStats)
}
override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, featureMap, withStats)
}
override def setParams(params: Map[String, AnyRef]): Unit = {
booster.setParams(params.asJava)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
}
override def dispose: Unit = {
booster.dispose()
}
override def predict(data: DMatrix): Array[Array[Float]] = {
booster.predict(data.jDMatrix)
}
override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin)
}
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, treeLimit, predLeaf)
}
override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
override def getFeatureScore: mutable.Map[String, Integer] = {
booster.getFeatureScore.asScala
}
override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
override def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
override def finalize(): Unit = {
super.finalize()
dispose
}
}

View File

@ -0,0 +1,52 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{XGBoost => JXGBoost}
object XGBoost {
def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int,
watches: Map[String, DMatrix], obj: ObjectiveTrait, eval: EvalTrait): Booster = {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
obj, eval)
new ScalaBoosterImpl(xgboostInJava)
}
def crossValiation(
params: Map[String, AnyRef],
data: DMatrix,
round: Int,
nfold: Int,
metrics: Array[String],
obj: ObjectiveTrait,
eval: EvalTrait): Array[String] = {
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
}
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
new ScalaBoosterImpl(xgboostInJava)
}
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
new ScalaBoosterImpl(xgboostInJava)
}
}

View File

@ -13,7 +13,7 @@
*/ */
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "xgboost4j_wrapper.h" #include "xgboost4j.h"
#include <cstring> #include <cstring>
//helper functions //helper functions
@ -24,7 +24,7 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out); jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
} }
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) { (JNIEnv *jenv, jclass jcls) {
jstring jresult = 0 ; jstring jresult = 0 ;
const char* result = XGBGetLastError(); const char* result = XGBGetLastError();
@ -32,7 +32,7 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastE
return jresult; return jresult;
} }
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -43,11 +43,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSR * Method: XGDMatrixCreateFromCSR
* Signature: ([J[J[F)J * Signature: ([J[J[F)J
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
@ -65,11 +65,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSC * Method: XGDMatrixCreateFromCSC
* Signature: ([J[J[F)J * Signature: ([J[J[F)J
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
@ -89,11 +89,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromMat * Method: XGDMatrixCreateFromMat
* Signature: ([FIIF)J * Signature: ([FIIF)J
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
jfloat* data = jenv->GetFloatArrayElements(jdata, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
@ -107,11 +107,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSliceDMatrix * Method: XGDMatrixSliceDMatrix
* Signature: (J[I)J * Signature: (J[I)J
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
@ -128,11 +128,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlice
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixFree * Method: XGDMatrixFree
* Signature: (J)V * Signature: (J)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) { (JNIEnv *jenv, jclass jcls, jlong jhandle) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
int ret = XGDMatrixFree(handle); int ret = XGDMatrixFree(handle);
@ -140,11 +140,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSaveBinary * Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)V * Signature: (JLjava/lang/String;I)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -154,11 +154,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetFloatInfo * Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)V * Signature: (JLjava/lang/String;[F)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -173,11 +173,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetUIntInfo * Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)V * Signature: (JLjava/lang/String;[I)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -192,11 +192,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetGroup * Method: XGDMatrixSetGroup
* Signature: (J[I)V * Signature: (J[I)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL); jint* array = jenv->GetIntArrayElements(jarray, NULL);
@ -208,11 +208,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixGetFloatInfo * Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;)[F * Signature: (JLjava/lang/String;)[F
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -230,11 +230,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFl
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixGetUIntInfo * Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;)[I * Signature: (JLjava/lang/String;)[I
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -251,11 +251,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUI
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixNumRow * Method: XGDMatrixNumRow
* Signature: (J)J * Signature: (J)J
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
bst_ulong result[1]; bst_ulong result[1];
@ -265,11 +265,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRo
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterCreate * Method: XGBoosterCreate
* Signature: ([J)J * Signature: ([J)J
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles; DMatrixHandle* handles;
bst_ulong len = 0; bst_ulong len = 0;
@ -298,11 +298,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreat
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterFree * Method: XGBoosterFree
* Signature: (J)V * Signature: (J)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) { (JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterFree(handle); return XGBoosterFree(handle);
@ -310,11 +310,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetParam * Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)V * Signature: (JLjava/lang/String;Ljava/lang/String;)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char* name = jenv->GetStringUTFChars(jname, 0); const char* name = jenv->GetStringUTFChars(jname, 0);
@ -327,11 +327,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetPa
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterUpdateOneIter * Method: XGBoosterUpdateOneIter
* Signature: (JIJ)V * Signature: (JIJ)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain; DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
@ -339,11 +339,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterBoostOneIter * Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V * Signature: (JJ[F[F)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain; DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
@ -358,11 +358,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterEvalOneIter * Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle* dmats = 0; DMatrixHandle* dmats = 0;
@ -406,11 +406,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterPredict * Method: XGBoosterPredict
* Signature: (JJIJ)[F * Signature: (JJIJ)[F
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dmat = (DMatrixHandle) jdmat; DMatrixHandle dmat = (DMatrixHandle) jdmat;
@ -426,11 +426,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredi
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModel * Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)V * Signature: (JLjava/lang/String;)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -441,11 +441,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveModel * Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)V * Signature: (JLjava/lang/String;)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -457,11 +457,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveM
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer * Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V * Signature: (JJJ)V
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
void *buf = (void*) jbuf; void *buf = (void*) jbuf;
@ -469,11 +469,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetModelRaw * Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String; * Signature: (J)Ljava/lang/String;
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0; bst_ulong len = 0;
@ -488,11 +488,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetMo
} }
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterDumpModel * Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I)[Ljava/lang/String; * Signature: (JLjava/lang/String;I)[Ljava/lang/String;
*/ */
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char *fmap = jenv->GetStringUTFChars(jfmap, 0); const char *fmap = jenv->GetStringUTFChars(jfmap, 0);

View File

@ -0,0 +1,221 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class ml_dmlc_xgboost4j_XgboostJNI */
#ifndef _Included_ml_dmlc_xgboost4j_XgboostJNI
#define _Included_ml_dmlc_xgboost4j_XgboostJNI
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBGetLastError
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *, jclass, jstring, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixCreateFromMat
* Signature: ([FIIF[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSliceDMatrix
* Signature: (J[I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *, jclass, jlong, jintArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *, jclass, jlong, jstring, jint);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;[[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGDMatrixNumRow
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterCreate
* Signature: ([J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *, jclass, jlongArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterUpdateOneIter
* Signature: (JIJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *, jclass, jlong, jint, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterPredict
* Signature: (JJII[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv *, jclass, jlong, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,138 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import junit.framework.TestCase;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
import java.util.*;
/**
* test cases for Booster
*
* @author hzx
*/
public class BoosterImplTest {
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for (int i = 0; i < nrow; i++) {
if (labels[i] == 0f && predicts[i][0] > 0) {
error++;
} else if (labels[i] == 1f && predicts[i][0] <= 0) {
error++;
}
}
return error / labels.length;
}
}
@Test
public void testBoosterBasic() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//set watchList
HashMap<String, DMatrix> watches = new HashMap<>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
//predict raw output
float[][] predicts = booster.predict(testMat, true);
//eval
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
//test dump model
}
/**
* test cross valiation
*
* @throws XGBoostError
*/
@Test
public void testCV() throws XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//set params
Map<String, Object> param = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 3);
put("silent", 1);
put("nthread", 6);
put("objective", "binary:logistic");
put("gamma", 1.0);
put("eval_metric", "error");
}
};
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
null, null);
}
}

View File

@ -0,0 +1,103 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
import junit.framework.TestCase;
import org.junit.Test;
import java.util.Arrays;
import java.util.Random;
/**
* test cases for DMatrix
*
* @author hzx
*/
public class DMatrixTest {
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
//get label
float[] labels = dmat.getLabel();
//check length
TestCase.assertTrue(dmat.rowNum() == labels.length);
//set weights
float[] weights = Arrays.copyOf(labels, labels.length);
dmat.setWeight(weights);
float[] dweights = dmat.getWeight();
TestCase.assertTrue(Arrays.equals(weights, dweights));
}
@Test
public void testCreateFromCSR() throws XGBoostError {
//create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
long[] rowHeaders = new long[]{0, 3, 7, 11};
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
//check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum() == 3);
//test set label
float[] label1 = new float[]{1, 0, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
float[] data0 = new float[nrow * ncol];
//put random nums
Random random = new Random();
for (int i = 0; i < nrow * ncol; i++) {
data0[i] = random.nextFloat();
}
//create label
float[] label0 = new float[nrow];
for (int i = 0; i < nrow; i++) {
label0[i] = random.nextFloat();
}
DMatrix dmat0 = new DMatrix(data0, nrow, ncol);
dmat0.setLabel(label0);
//check
TestCase.assertTrue(dmat0.rowNum() == 10);
TestCase.assertTrue(dmat0.getLabel().length == 10);
//set weights for each instance
float[] weights = new float[nrow];
for (int i = 0; i < nrow; i++) {
weights[i] = random.nextFloat();
}
dmat0.setWeight(weights);
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
}

View File

@ -73,10 +73,9 @@ fi
if [ ${TASK} == "java_test" ]; then if [ ${TASK} == "java_test" ]; then
set -e set -e
make java make jvm-packages
cd java cd jvm-packages
./create_wrap.sh ./create_jni.sh
cd xgboost4j
mvn clean install -DskipTests=true mvn clean install -DskipTests=true
mvn test mvn test
fi fi