[DOC-JVM] Refactor JVM docs
This commit is contained in:
15
jvm-packages/xgboost4j-example/LICENSE
Normal file
15
jvm-packages/xgboost4j-example/LICENSE
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
18
jvm-packages/xgboost4j-example/README.md
Normal file
18
jvm-packages/xgboost4j-example/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
XGBoost4J Code Examples
|
||||
=======================
|
||||
|
||||
## Java API
|
||||
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
|
||||
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
|
||||
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
|
||||
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
|
||||
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
|
||||
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
|
||||
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
|
||||
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
|
||||
|
||||
## Spark API
|
||||
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/spark/example/DistTrainWithSpark.scala)
|
||||
|
||||
## Flink API
|
||||
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/flink/example/DistTrainWithFlink.scala)
|
||||
42
jvm-packages/xgboost4j-example/pom.xml
Normal file
42
jvm-packages/xgboost4j-example/pom.xml
Normal file
@@ -0,0 +1,42 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>0.5</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-example</artifactId>
|
||||
<version>0.5</version>
|
||||
<packaging>jar</packaging>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipAssembly>false</skipAssembly>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark</artifactId>
|
||||
<version>0.5</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-flink</artifactId>
|
||||
<version>0.5</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -0,0 +1,120 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;

/**
 * A simple example (basic walkthrough) of the Java wrapper for XGBoost:
 * training, prediction, model save/load, and building a DMatrix from CSR data.
 *
 * @author hzx
 */
public class BasicWalkThrough {
  /**
   * Compares two prediction matrices element-wise.
   *
   * @param fPredicts first prediction matrix
   * @param sPredicts second prediction matrix
   * @return true iff both matrices have the same rows and every row is equal
   */
  public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
    if (fPredicts.length != sPredicts.length) {
      return false;
    }

    for (int i = 0; i < fPredicts.length; i++) {
      if (!Arrays.equals(fPredicts[i], sPredicts[i])) {
        return false;
      }
    }

    return true;
  }

  public static void main(String[] args) throws IOException, XGBoostError {
    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    // booster parameters
    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");

    // matrices to evaluate on during training
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //set round
    int round = 2;

    //train a boost model
    Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);

    //predict
    float[][] predicts = booster.predict(testMat);

    //save model to modelPath (create the directory first if missing)
    File file = new File("./model");
    if (!file.exists()) {
      file.mkdirs();
    }

    String modelPath = "./model/xgb.model";
    booster.saveModel(modelPath);

    //dump model
    booster.getModelDump("./model/dump.raw.txt", false);

    //dump model with feature map
    booster.getModelDump("../../demo/data/featmap.txt", false);

    //save dmatrix into binary buffer
    testMat.saveBinary("./model/dtest.buffer");

    //reload model and data
    Booster booster2 = XGBoost.loadModel("./model/xgb.model");
    DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
    float[][] predicts2 = booster2.predict(testMat2);

    //check that reloaded model/data reproduce the original predictions
    System.out.println(checkPredicts(predicts, predicts2));

    System.out.println("start build dmatrix from csr sparse data ...");
    //build dmatrix from CSR Sparse Matrix
    DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");

    DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
            DMatrix.SparseType.CSR);
    trainMat2.setLabel(spData.labels);

    //specify watchList
    HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
    watches2.put("train", trainMat2);
    watches2.put("test", testMat2);
    Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
    float[][] predicts3 = booster3.predict(testMat2);

    //check predicts
    System.out.println(checkPredicts(predicts, predicts3));
  }
}
|
||||
@@ -0,0 +1,62 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.HashMap;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Example of continuing training from an initial base prediction (margin):
 * train one round, feed the margins back via setBaseMargin, then train again.
 *
 * @author hzx
 */
public class BoostFromPrediction {
  public static void main(String[] args) throws XGBoostError {
    System.out.println("start running example to start from a initial prediction");

    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    //specify parameters
    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");

    //specify watchList
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //train xgboost for 1 round
    Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);

    // predict with outputMargin=true to get raw margins rather than probabilities
    float[][] trainPred = booster.predict(trainMat, true);
    float[][] testPred = booster.predict(testMat, true);

    // use the margins as the starting point for the next training run
    trainMat.setBaseMargin(trainPred);
    testMat.setBaseMargin(testPred);

    System.out.println("result of running from initial prediction");
    // the watch-list output during this call shows the effect of the base margin;
    // the returned booster itself is intentionally unused in this demo
    Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
  }
}
|
||||
@@ -0,0 +1,55 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.io.IOException;
import java.util.HashMap;

import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * An example of k-fold cross validation with XGBoost.
 *
 * @author hzx
 */
public class CrossValidation {
  public static void main(String[] args) throws IOException, XGBoostError {
    //load train mat
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");

    //set params
    HashMap<String, Object> params = new HashMap<String, Object>();

    params.put("eta", 1.0);
    params.put("max_depth", 3);
    params.put("silent", 1);
    params.put("nthread", 6);
    params.put("objective", "binary:logistic");
    params.put("gamma", 1.0);
    params.put("eval_metric", "error");

    //do 5-fold cross validation
    int round = 2;
    int nfold = 5;
    //set additional eval_metrics (null means use eval_metric from params only)
    String[] metrics = null;

    // evalHist holds the per-round evaluation history; crossValidation itself
    // prints the progress, so the return value is not used further here
    String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
            null);
  }
}
|
||||
@@ -0,0 +1,168 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import ml.dmlc.xgboost4j.java.*;

/**
 * An example of a user-defined objective and evaluation metric.
 * NOTE: when you use a customized loss function, the default prediction value is the margin.
 * This may make the built-in evaluation metrics not function properly.
 * For example, with logistic loss the prediction is the score before the logistic
 * transformation, while the built-in evaluation error assumes the input is after the
 * logistic transformation.
 * Keep this in mind when using the customization, and supply a customized evaluation
 * function if needed.
 *
 * @author hzx
 */
public class CustomObjective {
  /**
   * Log-likelihood loss objective function (logistic regression).
   */
  public static class LogRegObj implements IObjective {
    private static final Log logger = LogFactory.getLog(LogRegObj.class);

    /**
     * Simple sigmoid function.
     *
     * @param input raw margin score
     * @return sigmoid(input). Note: this is not concerned with numerical
     *         stability; it is only used as an example.
     */
    public float sigmoid(float input) {
      float val = (float) (1 / (1 + Math.exp(-input)));
      return val;
    }

    /**
     * Applies the sigmoid to the first column of each prediction row,
     * turning margins into probabilities.
     */
    public float[][] transform(float[][] predicts) {
      int nrow = predicts.length;
      float[][] transPredicts = new float[nrow][1];

      for (int i = 0; i < nrow; i++) {
        transPredicts[i][0] = sigmoid(predicts[i][0]);
      }

      return transPredicts;
    }

    @Override
    public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
      int nrow = predicts.length;
      List<float[]> gradients = new ArrayList<float[]>();
      float[] labels;
      try {
        labels = dtrain.getLabel();
      } catch (XGBoostError ex) {
        logger.error(ex);
        return null;
      }
      float[] grad = new float[nrow];
      float[] hess = new float[nrow];

      float[][] transPredicts = transform(predicts);

      // gradient and hessian of the logistic loss w.r.t. the margin
      for (int i = 0; i < nrow; i++) {
        float predict = transPredicts[i][0];
        grad[i] = predict - labels[i];
        hess[i] = predict * (1 - predict);
      }

      gradients.add(grad);
      gradients.add(hess);
      return gradients;
    }
  }

  /**
   * User-defined eval function: classification error with the decision
   * threshold at a margin of 0 (since predictions are raw margins here).
   * NOTE: when you use a customized loss function, the default prediction value is the
   * margin, which may make the built-in evaluation metrics not function properly — hence
   * this customized evaluation function.
   */
  public static class EvalError implements IEvaluation {
    private static final Log logger = LogFactory.getLog(EvalError.class);

    String evalMetric = "custom_error";

    public EvalError() {
    }

    @Override
    public String getMetric() {
      return evalMetric;
    }

    @Override
    public float eval(float[][] predicts, DMatrix dmat) {
      float error = 0f;
      float[] labels;
      try {
        labels = dmat.getLabel();
      } catch (XGBoostError ex) {
        logger.error(ex);
        return -1f;
      }
      int nrow = predicts.length;
      for (int i = 0; i < nrow; i++) {
        if (labels[i] == 0f && predicts[i][0] > 0) {
          error++;
        } else if (labels[i] == 1f && predicts[i][0] <= 0) {
          error++;
        }
      }

      return error / labels.length;
    }
  }

  public static void main(String[] args) throws XGBoostError {
    //load train mat (svmlight format)
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    //load valid mat (svmlight format)
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("silent", 1);

    //set round
    int round = 2;

    //specify watchList
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //user define obj and eval
    IObjective obj = new LogRegObj();
    IEvaluation eval = new EvalError();

    //train a booster
    System.out.println("begin to train the booster model");
    Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
  }
}
|
||||
@@ -0,0 +1,61 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.HashMap;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Simple example of using the external-memory version of XGBoost.
 *
 * @author hzx
 */
public class ExternalMemory {
  public static void main(String[] args) throws XGBoostError {
    //this is the only difference, add a # followed by a cache prefix name
    //several cache file with the prefix will be generated
    //currently only support convert from libsvm file
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");

    //specify parameters
    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");

    //performance notice: set nthread to be the number of your real cpu
    //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
    // set nthread=4
    //param.put("nthread", num_real_cpu);

    //specify watchList
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //set round
    int round = 2;

    //train a boost model
    Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
  }
}
|
||||
@@ -0,0 +1,70 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.HashMap;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.CustomEval;

/**
 * An example of fitting a generalized linear model in XGBoost:
 * basically, we use a linear model, instead of trees, for our boosters.
 *
 * @author hzx
 */
public class GeneralizedLinearModel {
  public static void main(String[] args) throws XGBoostError {
    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    //specify parameters
    //change booster to gblinear, so that we are fitting a linear model
    //alpha is the L1 regularizer
    //lambda is the L2 regularizer
    //you can also set lambda_bias which is L2 regularizer on the bias term
    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("alpha", 0.0001);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");
    params.put("booster", "gblinear");

    //normally, you do not need to set eta (step_size)
    //XGBoost uses a parallel coordinate descent algorithm (shotgun),
    //there could be affection on convergence with parallelization on certain cases
    //setting eta to be smaller value, e.g 0.5 can make the optimization more stable
    //param.put("eta", "0.5");

    //specify watchList
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //train a booster
    int round = 4;
    Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);

    float[][] predicts = booster.predict(testMat);

    // report classification error using the shared example evaluator
    CustomEval eval = new CustomEval();
    System.out.println("error=" + eval.eval(predicts, testMat));
  }
}
|
||||
@@ -0,0 +1,66 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.HashMap;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.CustomEval;

/**
 * Example of predicting with only the first n trees of a boosted model.
 *
 * @author hzx
 */
public class PredictFirstNtree {
  public static void main(String[] args) throws XGBoostError {
    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    //specify parameters
    HashMap<String, Object> params = new HashMap<String, Object>();

    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");

    //specify watchList
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //train a booster
    int round = 3;
    Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);

    //predict use 1 tree
    float[][] predicts1 = booster.predict(testMat, false, 1);
    //by default all trees are used to do predict
    float[][] predicts2 = booster.predict(testMat);

    //use a simple evaluation class to check error result
    CustomEval eval = new CustomEval();
    System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
    System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
  }
}
|
||||
@@ -0,0 +1,66 @@
|
||||
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.Arrays;
import java.util.HashMap;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Example of predicting the leaf indices each sample falls into.
 *
 * @author hzx
 */
public class PredictLeafIndices {
  public static void main(String[] args) throws XGBoostError {
    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    //specify parameters
    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");

    //specify watchList
    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);
    watches.put("test", testMat);

    //train a booster
    int round = 3;
    Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);

    //predict using first 2 tree
    float[][] leafindex = booster.predictLeaf(testMat, 2);
    for (float[] leafs : leafindex) {
      System.out.println(Arrays.toString(leafs));
    }

    //predict all trees (treeLimit = 0 means no limit)
    leafindex = booster.predictLeaf(testMat, 0);
    for (float[] leafs : leafindex) {
      System.out.println(Arrays.toString(leafs));
    }
  }
}
|
||||
@@ -0,0 +1,61 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.util;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.IEvaluation;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* a util evaluation class for examples
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class CustomEval implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
if (labels[i] == 0f && predicts[i][0] > 0.5) {
|
||||
error++;
|
||||
} else if (labels[i] == 1f && predicts[i][0] <= 0.5) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error / labels.length;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.util;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
|
||||
/**
|
||||
* util class for loading data
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class DataLoader {
|
||||
public static class DenseData {
|
||||
public float[] labels;
|
||||
public float[] data;
|
||||
public int nrow;
|
||||
public int ncol;
|
||||
}
|
||||
|
||||
public static class CSRSparseData {
|
||||
public float[] labels;
|
||||
public float[] data;
|
||||
public long[] rowHeaders;
|
||||
public int[] colIndex;
|
||||
}
|
||||
|
||||
public static DenseData loadCSVFile(String filePath) throws IOException {
|
||||
DenseData denseData = new DenseData();
|
||||
|
||||
File f = new File(filePath);
|
||||
FileInputStream in = new FileInputStream(f);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||
|
||||
denseData.nrow = 0;
|
||||
denseData.ncol = -1;
|
||||
String line;
|
||||
List<Float> tlabels = new ArrayList<>();
|
||||
List<Float> tdata = new ArrayList<>();
|
||||
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] items = line.trim().split(",");
|
||||
if (items.length == 0) {
|
||||
continue;
|
||||
}
|
||||
denseData.nrow++;
|
||||
if (denseData.ncol == -1) {
|
||||
denseData.ncol = items.length - 1;
|
||||
}
|
||||
|
||||
tlabels.add(Float.valueOf(items[items.length - 1]));
|
||||
for (int i = 0; i < items.length - 1; i++) {
|
||||
tdata.add(Float.valueOf(items[i]));
|
||||
}
|
||||
}
|
||||
|
||||
reader.close();
|
||||
in.close();
|
||||
|
||||
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
|
||||
return denseData;
|
||||
}
|
||||
|
||||
public static CSRSparseData loadSVMFile(String filePath) throws IOException {
|
||||
CSRSparseData spData = new CSRSparseData();
|
||||
|
||||
List<Float> tlabels = new ArrayList<>();
|
||||
List<Float> tdata = new ArrayList<>();
|
||||
List<Long> theaders = new ArrayList<>();
|
||||
List<Integer> tindex = new ArrayList<>();
|
||||
|
||||
File f = new File(filePath);
|
||||
FileInputStream in = new FileInputStream(f);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||
|
||||
String line;
|
||||
long rowheader = 0;
|
||||
theaders.add(rowheader);
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] items = line.trim().split(" ");
|
||||
if (items.length == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
rowheader += items.length - 1;
|
||||
theaders.add(rowheader);
|
||||
tlabels.add(Float.valueOf(items[0]));
|
||||
|
||||
for (int i = 1; i < items.length; i++) {
|
||||
String[] tup = items[i].split(":");
|
||||
assert tup.length == 2;
|
||||
|
||||
tdata.add(Float.valueOf(tup[1]));
|
||||
tindex.add(Integer.valueOf(tup[0]));
|
||||
}
|
||||
}
|
||||
|
||||
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
|
||||
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
|
||||
|
||||
return spData;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.flink.example
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.flink.XGBoost
|
||||
import org.apache.flink.api.scala._
|
||||
import org.apache.flink.api.scala.ExecutionEnvironment
|
||||
import org.apache.flink.ml.MLUtils
|
||||
|
||||
/**
 * Minimal example of distributed XGBoost training on Apache Flink:
 * read a LibSVM data set, fit a booster, score the training set, and
 * persist the model to a Hadoop-compatible file system.
 */
object DistTrainWithFlink {
  def main(args: Array[String]) {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    // read training data in LibSVM format
    val trainData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
    // booster parameters
    val paramMap = Map(
      "eta" -> 0.1,
      "max_depth" -> 2,
      "objective" -> "binary:logistic")
    // number of boosting iterations
    val round = 2
    // fit the model on the training set
    val model = XGBoost.train(paramMap, trainData, round)
    // score the training set with the fitted model
    val trainPredictions = model.predict(trainData.map(sample => sample.vector))
    // persist the fitted model
    model.saveModelToHadoop("file:///path/to/xgboost.model")
  }
}
|
||||
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.example
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.io.Source
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.linalg.DenseVector
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoost
|
||||
|
||||
|
||||
/**
 * Example of distributed XGBoost training with Spark: read LibSVM-format
 * train/test files, parallelize the training set into an RDD, train a
 * booster, and score the test matrix.
 */
object DistTrainWithSpark {

  // size of the dense feature vector built from the sparse SVM lines;
  // assumes feature indices in the input are < 129 — TODO confirm against
  // the agaricus data set this example targets
  private val NumFeatures = 129

  /**
   * Reads a LibSVM-format file into a list of labeled points.
   *
   * @param filePath path of the input file
   * @return one LabeledPoint per input line
   */
  private def readFile(filePath: String): List[LabeledPoint] = {
    val file = Source.fromFile(new File(filePath))
    try {
      // materialize before closing: getLines() is lazy
      file.getLines().map(fromSVMStringToLabeledPoint).toList
    } finally {
      // fix: the source was previously never closed (file-handle leak)
      file.close()
    }
  }

  /**
   * Parses one "label idx:value idx:value ..." line into a dense
   * LabeledPoint with NumFeatures slots (missing features stay 0.0).
   */
  private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
    val labelAndFeatures = line.split(" ")
    val label = labelAndFeatures(0).toInt
    val features = labelAndFeatures.tail
    val denseFeature = new Array[Double](NumFeatures)
    for (feature <- features) {
      val idAndValue = feature.split(":")
      denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
    }
    LabeledPoint(label, new DenseVector(denseFeature))
  }

  /**
   * Entry point. Expects exactly four arguments:
   * number_of_trainingset_partitions, num_of_rounds, training_path, test_path.
   */
  def main(args: Array[String]): Unit = {
    import ml.dmlc.xgboost4j.scala.spark.DataUtils._
    if (args.length != 4) {
      println(
        "usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
      sys.exit(1)
    }
    val sc = new SparkContext()
    val inputTrainPath = args(2)
    val inputTestPath = args(3)
    val trainingLabeledPoints = readFile(inputTrainPath)
    val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
    val testLabeledPoints = readFile(inputTestPath).iterator
    val testMatrix = new DMatrix(testLabeledPoints, null)
    val booster = XGBoost.train(trainRDD,
      List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
        "objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
    // NOTE(review): the prediction result is discarded — this example only
    // demonstrates that the trained booster can score the held-out matrix
    booster.map(boosterInstance => boosterInstance.predict(testMatrix))
  }
}
|
||||
Reference in New Issue
Block a user