add java wrapper
This commit is contained in:
15
java/xgboost4j-demo/LICENSE
Normal file
15
java/xgboost4j-demo/LICENSE
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
10
java/xgboost4j-demo/README.md
Normal file
10
java/xgboost4j-demo/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
xgboost4j examples
|
||||
====
|
||||
* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
|
||||
* [Customize loss function and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
|
||||
* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
|
||||
* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
|
||||
* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
|
||||
* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
|
||||
* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
|
||||
* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)
|
||||
36
java/xgboost4j-demo/pom.xml
Normal file
36
java/xgboost4j-demo/pom.xml
Normal file
@@ -0,0 +1,36 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<artifactId>xgboost4j-demo</artifactId>
|
||||
<version>1.0</version>
|
||||
<packaging>jar</packaging>
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
</properties>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>1.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
<version>2.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.11</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -0,0 +1,117 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.Arrays;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
/**
|
||||
* a simple example of java wrapper for xgboost
|
||||
* @author hzx
|
||||
*/
|
||||
public class BasicWalkThrough {
|
||||
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
|
||||
if(fPredicts.length != sPredicts.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for(int i=0; i<fPredicts.length; i++) {
|
||||
if(!Arrays.equals(fPredicts[i], sPredicts[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) throws UnsupportedEncodingException, IOException {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
|
||||
//predict
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
//save model to modelPath
|
||||
File file = new File("./model");
|
||||
if(!file.exists()) {
|
||||
file.mkdirs();
|
||||
}
|
||||
|
||||
String modelPath = "./model/xgb.model";
|
||||
booster.saveModel(modelPath);
|
||||
|
||||
//dump model
|
||||
booster.dumpModel("./model/dump.raw.txt", false);
|
||||
|
||||
//dump model with feature map
|
||||
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
||||
|
||||
//save dmatrix into binary buffer
|
||||
testMat.saveBinary("./model/dtest.buffer");
|
||||
|
||||
//reload model and data
|
||||
Booster booster2 = new Booster(param, "./model/xgb.model");
|
||||
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
||||
float[][] predicts2 = booster2.predict(testMat2);
|
||||
|
||||
|
||||
//check the two predicts
|
||||
System.out.println(checkPredicts(predicts, predicts2));
|
||||
|
||||
System.out.println("start build dmatrix from csr sparse data ...");
|
||||
//build dmatrix from CSR Sparse Matrix
|
||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
|
||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
|
||||
trainMat2.setLabel(spData.labels);
|
||||
|
||||
dmats = new DMatrix[] {trainMat2, testMat};
|
||||
Booster booster3 = Trainer.train(param, trainMat2, round, dmats, evalNames, null, null);
|
||||
float[][] predicts3 = booster3.predict(testMat2);
|
||||
|
||||
//check predicts
|
||||
System.out.println(checkPredicts(predicts, predicts3));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
/**
|
||||
* example of starting from an initial base prediction
|
||||
* @author hzx
|
||||
*/
|
||||
public class BoostFromPrediction {
|
||||
public static void main(String[] args) {
|
||||
System.out.println("start running example to start from a initial prediction");
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
|
||||
//train xgboost for 1 round
|
||||
Booster booster = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null);
|
||||
|
||||
float[][] trainPred = booster.predict(trainMat, true);
|
||||
float[][] testPred = booster.predict(testMat, true);
|
||||
|
||||
trainMat.setBaseMargin(trainPred);
|
||||
testMat.setBaseMargin(testPred);
|
||||
|
||||
System.out.println("result of running from initial prediction");
|
||||
Booster booster2 = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
|
||||
/**
|
||||
* an example of cross validation
|
||||
* @author hzx
|
||||
*/
|
||||
public class CrossValidation {
|
||||
public static void main(String[] args) throws IOException {
|
||||
//load train mat
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
|
||||
//set params
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "3");
|
||||
put("silent", "1");
|
||||
put("nthread", "6");
|
||||
put("objective", "binary:logistic");
|
||||
put("gamma", "1.0");
|
||||
put("eval_metric", "error");
|
||||
}
|
||||
};
|
||||
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
//set additional eval_metrics
|
||||
String[] metrics = null;
|
||||
|
||||
String[] evalHist = Trainer.crossValiation(param, trainMat, round, nfold, metrics, null, null);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IObjective;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||
* this may make the builtin evaluation metric not function properly
|
||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||
* the builtin evaluation error assumes input is after logistic transformation
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
||||
* @author hzx
|
||||
*/
|
||||
public class CustomObjective {
|
||||
/**
|
||||
* loglikelihoode loss obj function
|
||||
*/
|
||||
public static class LogRegObj implements IObjective {
|
||||
/**
|
||||
* simple sigmoid func
|
||||
* @param input
|
||||
* @return
|
||||
* Note: this func is not concern about numerical stability, only used as example
|
||||
*/
|
||||
public float sigmoid(float input) {
|
||||
float val = (float) (1/(1+Math.exp(-input)));
|
||||
return val;
|
||||
}
|
||||
|
||||
public float[][] transform(float[][] predicts) {
|
||||
int nrow = predicts.length;
|
||||
float[][] transPredicts = new float[nrow][1];
|
||||
|
||||
for(int i=0; i<nrow; i++) {
|
||||
transPredicts[i][0] = sigmoid(predicts[i][0]);
|
||||
}
|
||||
|
||||
return transPredicts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
|
||||
int nrow = predicts.length;
|
||||
List<float[]> gradients = new ArrayList<>();
|
||||
float[] labels = dtrain.getLabel();
|
||||
float[] grad = new float[nrow];
|
||||
float[] hess = new float[nrow];
|
||||
|
||||
float[][] transPredicts = transform(predicts);
|
||||
|
||||
for(int i=0; i<nrow; i++) {
|
||||
float predict = transPredicts[i][0];
|
||||
grad[i] = predict - labels[i];
|
||||
hess[i] = predict * (1 - predict);
|
||||
}
|
||||
|
||||
gradients.add(grad);
|
||||
gradients.add(hess);
|
||||
return gradients;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* user defined eval function.
|
||||
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||
* this may make buildin evalution metric not function properly
|
||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||
* the buildin evaluation error assumes input is after logistic transformation
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
||||
*/
|
||||
public static class EvalError implements IEvaluation {
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
public EvalError() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels = dmat.getLabel();
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0) {
|
||||
error++;
|
||||
}
|
||||
else if(labels[i]==1f && predicts[i][0]<=0) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error/labels.length;
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
//load train mat (svmlight format)
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
//load valid mat (svmlight format)
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//set params
|
||||
//set params
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
}
|
||||
};
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//set evaluation data
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "eval"};
|
||||
|
||||
//user define obj and eval
|
||||
IObjective obj = new LogRegObj();
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
//train a booster
|
||||
System.out.println("begin to train the booster model");
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, obj, eval);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
/**
|
||||
* simple example for using external memory version
|
||||
* @author hzx
|
||||
*/
|
||||
public class ExternalMemory {
|
||||
public static void main(String[] args) {
|
||||
//this is the only difference, add a # followed by a cache prefix name
|
||||
//several cache file with the prefix will be generated
|
||||
//currently only support convert from libsvm file
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
||||
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//performance notice: set nthread to be the number of your real cpu
|
||||
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
|
||||
//param.put("nthread", "num_real_cpu");
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
/**
|
||||
* this is an example of fitting a generalized linear model in xgboost
|
||||
* basically, we are using linear model, instead of tree for our boosters
|
||||
* @author hzx
|
||||
*/
|
||||
public class GeneralizedLinearModel {
|
||||
public static void main(String[] args) {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
//change booster to gblinear, so that we are fitting a linear model
|
||||
// alpha is the L1 regularizer
|
||||
//lambda is the L2 regularizer
|
||||
//you can also set lambda_bias which is L2 regularizer on the bias term
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("alpha", "0.0001");
|
||||
put("silent", "1");
|
||||
put("objective", "binary:logistic");
|
||||
put("booster", "gblinear");
|
||||
}
|
||||
};
|
||||
//normally, you do not need to set eta (step_size)
|
||||
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
||||
//there could be affection on convergence with parallelization on certain cases
|
||||
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
||||
//param.put("eta", "0.5");
|
||||
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
|
||||
//train a booster
|
||||
int round = 4;
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
CustomEval eval = new CustomEval();
|
||||
System.out.println("error=" + eval.eval(predicts, testMat));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
|
||||
/**
|
||||
* predict first ntree
|
||||
* @author hzx
|
||||
*/
|
||||
public class PredictFirstNtree {
|
||||
public static void main(String[] args) {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
|
||||
//predict use 1 tree
|
||||
float[][] predicts1 = booster.predict(testMat, false, 1);
|
||||
//by default all trees are used to do predict
|
||||
float[][] predicts2 = booster.predict(testMat);
|
||||
|
||||
//use a simple evaluation class to check error result
|
||||
CustomEval eval = new CustomEval();
|
||||
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
|
||||
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.Arrays;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
/**
|
||||
* predict leaf indices
|
||||
* @author hzx
|
||||
*/
|
||||
public class PredictLeafIndices {
|
||||
public static void main(String[] args) {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
|
||||
//predict using first 2 tree
|
||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||
for(float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
|
||||
//predict all trees
|
||||
leafindex = booster.predict(testMat, 0, true);
|
||||
for(float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo.util;
|
||||
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
|
||||
/**
|
||||
* a util evaluation class for examples
|
||||
* @author hzx
|
||||
*/
|
||||
public class CustomEval implements IEvaluation {
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels = dmat.getLabel();
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0.5) {
|
||||
error++;
|
||||
}
|
||||
else if(labels[i]==1f && predicts[i][0]<=0.5) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error/labels.length;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo.util;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
|
||||
/**
|
||||
* util class for loading data
|
||||
* @author hzx
|
||||
*/
|
||||
/**
 * Utility class for loading demo data sets from disk.
 *
 * @author hzx
 */
public class DataLoader {

    /** Dense, row-major feature data together with labels, read from CSV. */
    public static class DenseData {
        public float[] labels;
        public float[] data;   // row-major, nrow * ncol values
        public int nrow;
        public int ncol;
    }

    /** Sparse data in CSR layout, read from a libsvm file. */
    public static class CSRSparseData {
        public float[] labels;
        public float[] data;        // non-zero feature values
        public long[] rowHeaders;   // CSR row pointers, length nrow + 1
        public int[] colIndex;      // column index for each entry of data
    }

    /**
     * Loads a CSV file where the LAST column of each row is the label and the
     * remaining columns are features.
     *
     * @param filePath path of the CSV file
     * @return parsed dense data
     * @throws FileNotFoundException if the file does not exist
     * @throws UnsupportedEncodingException if UTF-8 is unavailable
     * @throws IOException if reading fails
     */
    public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
        DenseData denseData = new DenseData();
        denseData.nrow = 0;
        denseData.ncol = -1;

        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();

        // try-with-resources: the stream is closed even when parsing throws
        // (the previous version leaked it on any exception)
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    // skip blank lines instead of failing on Float.valueOf("")
                    continue;
                }
                String[] items = line.split(",");
                denseData.nrow++;
                if (denseData.ncol == -1) {
                    // the first data row determines the feature count
                    denseData.ncol = items.length - 1;
                }

                // last column is the label, the rest are features
                tlabels.add(Float.valueOf(items[items.length - 1]));
                for (int i = 0; i < items.length - 1; i++) {
                    tdata.add(Float.valueOf(items[i]));
                }
            }
        }

        // FIX: the previous implementation used "(Float[]) tlabels.toArray()",
        // which always throws ClassCastException because List.toArray() with
        // no argument returns Object[], not Float[].
        denseData.labels = toFloatArray(tlabels);
        denseData.data = toFloatArray(tdata);

        return denseData;
    }

    /**
     * Loads a libsvm/SVMLight format file ("label index:value index:value ...")
     * into CSR sparse arrays.
     *
     * @param filePath path of the libsvm file
     * @return parsed CSR sparse data
     * @throws FileNotFoundException if the file does not exist
     * @throws UnsupportedEncodingException if UTF-8 is unavailable
     * @throws IOException if reading fails
     */
    public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
        CSRSparseData spData = new CSRSparseData();

        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        List<Long> theaders = new ArrayList<>();
        List<Integer> tindex = new ArrayList<>();

        // CSR row pointer; starts at 0 and grows by the entry count of each row
        long rowheader = 0;
        theaders.add(rowheader);

        // try-with-resources: the previous version never closed this reader
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    // skip blank lines instead of failing on Float.valueOf("")
                    continue;
                }
                String[] items = line.split(" ");

                // one CSR row per input line; items[0] is the label
                rowheader += items.length - 1;
                theaders.add(rowheader);
                tlabels.add(Float.valueOf(items[0]));

                for (int i = 1; i < items.length; i++) {
                    String[] tup = items[i].split(":");
                    assert tup.length == 2;

                    tdata.add(Float.valueOf(tup[1]));
                    tindex.add(Integer.valueOf(tup[0]));
                }
            }
        }

        spData.labels = toFloatArray(tlabels);
        spData.data = toFloatArray(tdata);
        spData.colIndex = toIntArray(tindex);
        spData.rowHeaders = toLongArray(theaders);

        return spData;
    }

    // unboxing helpers; replace the commons-lang3 ArrayUtils.toPrimitive calls
    // so this class depends only on the standard library
    private static float[] toFloatArray(List<Float> values) {
        float[] out = new float[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i);
        }
        return out;
    }

    private static int[] toIntArray(List<Integer> values) {
        int[] out = new int[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i);
        }
        return out;
    }

    private static long[] toLongArray(List<Long> values) {
        long[] out = new long[values.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = values.get(i);
        }
        return out;
    }
}
|
||||
Reference in New Issue
Block a user