make some fix
This commit is contained in:
parent
f91a098770
commit
c110111f52
@ -17,11 +17,11 @@ core of this wrapper is two classes:
|
|||||||
|
|
||||||
## build native library
|
## build native library
|
||||||
|
|
||||||
for windows: open the xgboost.sln in windows folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library:
|
for windows: open the xgboost.sln in "../windows" folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library:
|
||||||
* Select x64/win32 and Release in build
|
* Select x64/win32 and Release in build
|
||||||
* (if you have setted `JAVA_HOME` properly in windows environment variables, escape this step) right click on xgboostjavawrapper project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path.
|
* (if you have setted `JAVA_HOME` properly in windows environment variables, escape this step) right click on xgboostjavawrapper project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path.
|
||||||
* rebuild all
|
* rebuild all
|
||||||
* move the dll "xgboostjavawrapper.dll" to "xgboost4j/src/main/resources/lib/"(you may need to create this folder if necessary.)
|
* double click "create_wrap.bat" to set library to proper place
|
||||||
|
|
||||||
for linux:
|
for linux:
|
||||||
* make sure you have installed jdk and `JAVA_HOME` has been setted properly
|
* make sure you have installed jdk and `JAVA_HOME` has been setted properly
|
||||||
|
|||||||
20
java/create_wrap.bat
Normal file
20
java/create_wrap.bat
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
echo "move native library"
|
||||||
|
set libsource=..\windows\x64\Release\xgboostjavawrapper.dll
|
||||||
|
|
||||||
|
if not exist %libsource% (
|
||||||
|
goto end
|
||||||
|
)
|
||||||
|
|
||||||
|
set libfolder=xgboost4j\src\main\resources\lib
|
||||||
|
set libpath=%libfolder%\xgboostjavawrapper.dll
|
||||||
|
if not exist %libfolder% (mkdir %libfolder%)
|
||||||
|
if exist %libpath% (del %libpath%)
|
||||||
|
move %libsource% %libfolder%
|
||||||
|
echo complete
|
||||||
|
pause
|
||||||
|
exit
|
||||||
|
|
||||||
|
:end
|
||||||
|
echo "source library not found, please build it first from ..\windows\xgboost.sln"
|
||||||
|
pause
|
||||||
|
exit
|
||||||
@ -6,7 +6,7 @@ echo "move native lib"
|
|||||||
|
|
||||||
libPath="xgboost4j/src/main/resources/lib"
|
libPath="xgboost4j/src/main/resources/lib"
|
||||||
if [ ! -d "$libPath" ]; then
|
if [ ! -d "$libPath" ]; then
|
||||||
mkdir "$libPath"
|
mkdir -p "$libPath"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
rm -f xgboost4j/src/main/resources/lib/libxgboostjavawrapper.so
|
rm -f xgboost4j/src/main/resources/lib/libxgboostjavawrapper.so
|
||||||
|
|||||||
@ -82,9 +82,9 @@ import org.dmlc.xgboost4j.util.Params;
|
|||||||
```java
|
```java
|
||||||
Params params = new Params() {
|
Params params = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
put("eval_metric", "logloss");
|
put("eval_metric", "logloss");
|
||||||
}
|
}
|
||||||
@ -94,9 +94,9 @@ Params params = new Params() {
|
|||||||
```java
|
```java
|
||||||
Params params = new Params() {
|
Params params = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
put("eval_metric", "logloss");
|
put("eval_metric", "logloss");
|
||||||
put("eval_metric", "error");
|
put("eval_metric", "error");
|
||||||
@ -110,16 +110,19 @@ With parameters and data, you are able to train a booster model.
|
|||||||
```java
|
```java
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
```
|
```
|
||||||
|
|
||||||
* Training
|
* Training
|
||||||
```java
|
```java
|
||||||
DMatrix trainMat = new DMatrix("train.svm.txt");
|
DMatrix trainMat = new DMatrix("train.svm.txt");
|
||||||
DMatrix validMat = new DMatrix("valid.svm.txt");
|
DMatrix validMat = new DMatrix("valid.svm.txt");
|
||||||
DMatrix[] evalMats = new DMatrix[] {trainMat, validMat};
|
//specifiy a watchList to see the performance
|
||||||
String[] evalNames = new String[] {"train", "valid"};
|
WatchList watchs = new WatchList();
|
||||||
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
int round = 2;
|
int round = 2;
|
||||||
Booster booster = Trainer.train(params, trainMat, round, evalMats, evalNames, null, null);
|
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
|
||||||
```
|
```
|
||||||
|
|
||||||
* Saving model
|
* Saving model
|
||||||
@ -139,8 +142,8 @@ booster.dumpModel("modelInfo.txt", "featureMap.txt", false)
|
|||||||
```java
|
```java
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("nthread", "6");
|
put("nthread", 6);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Booster booster = new Booster(param, "model.bin");
|
Booster booster = new Booster(param, "model.bin");
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
|||||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* a simple example of java wrapper for xgboost
|
* a simple example of java wrapper for xgboost
|
||||||
@ -53,22 +54,23 @@ public class BasicWalkThrough {
|
|||||||
//specify parameters
|
//specify parameters
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//specify evaluate datasets and evaluate names
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "test"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//set round
|
//set round
|
||||||
int round = 2;
|
int round = 2;
|
||||||
|
|
||||||
//train a boost model
|
//train a boost model
|
||||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||||
|
|
||||||
//predict
|
//predict
|
||||||
float[][] predicts = booster.predict(testMat);
|
float[][] predicts = booster.predict(testMat);
|
||||||
@ -107,8 +109,11 @@ public class BasicWalkThrough {
|
|||||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
|
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
|
||||||
trainMat2.setLabel(spData.labels);
|
trainMat2.setLabel(spData.labels);
|
||||||
|
|
||||||
dmats = new DMatrix[] {trainMat2, testMat};
|
//specify watchList
|
||||||
Booster booster3 = Trainer.train(param, trainMat2, round, dmats, evalNames, null, null);
|
WatchList watchs2 = new WatchList();
|
||||||
|
watchs2.put("train", trainMat2);
|
||||||
|
watchs2.put("test", testMat);
|
||||||
|
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
|
||||||
float[][] predicts3 = booster3.predict(testMat2);
|
float[][] predicts3 = booster3.predict(testMat2);
|
||||||
|
|
||||||
//check predicts
|
//check predicts
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import org.dmlc.xgboost4j.Booster;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* example for start from a initial base prediction
|
* example for start from a initial base prediction
|
||||||
@ -35,19 +36,20 @@ public class BoostFromPrediction {
|
|||||||
//specify parameters
|
//specify parameters
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//specify evaluate datasets and evaluate names
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "test"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//train xgboost for 1 round
|
//train xgboost for 1 round
|
||||||
Booster booster = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null);
|
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);
|
||||||
|
|
||||||
float[][] trainPred = booster.predict(trainMat, true);
|
float[][] trainPred = booster.predict(trainMat, true);
|
||||||
float[][] testPred = booster.predict(testMat, true);
|
float[][] testPred = booster.predict(testMat, true);
|
||||||
@ -56,6 +58,6 @@ public class BoostFromPrediction {
|
|||||||
testMat.setBaseMargin(testPred);
|
testMat.setBaseMargin(testPred);
|
||||||
|
|
||||||
System.out.println("result of running from initial prediction");
|
System.out.println("result of running from initial prediction");
|
||||||
Booster booster2 = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null);
|
Booster booster2 = Trainer.train(param, trainMat, 1, watchs, null, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,12 +32,12 @@ public class CrossValidation {
|
|||||||
//set params
|
//set params
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "3");
|
put("max_depth", 3);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("nthread", "6");
|
put("nthread", 6);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
put("gamma", "1.0");
|
put("gamma", 1.0);
|
||||||
put("eval_metric", "error");
|
put("eval_metric", "error");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -16,7 +16,6 @@
|
|||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.IEvaluation;
|
import org.dmlc.xgboost4j.IEvaluation;
|
||||||
@ -24,6 +23,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
|||||||
import org.dmlc.xgboost4j.IObjective;
|
import org.dmlc.xgboost4j.IObjective;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* an example user define objective and eval
|
* an example user define objective and eval
|
||||||
@ -130,18 +130,19 @@ public class CustomObjective {
|
|||||||
//set params
|
//set params
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//set round
|
//set round
|
||||||
int round = 2;
|
int round = 2;
|
||||||
|
|
||||||
//set evaluation data
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "eval"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//user define obj and eval
|
//user define obj and eval
|
||||||
IObjective obj = new LogRegObj();
|
IObjective obj = new LogRegObj();
|
||||||
@ -149,6 +150,6 @@ public class CustomObjective {
|
|||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
System.out.println("begin to train the booster model");
|
System.out.println("begin to train the booster model");
|
||||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, obj, eval);
|
Booster booster = Trainer.train(param, trainMat, round, watchs, obj, eval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import org.dmlc.xgboost4j.Booster;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* simple example for using external memory version
|
* simple example for using external memory version
|
||||||
@ -35,25 +36,26 @@ public class ExternalMemory {
|
|||||||
//specify parameters
|
//specify parameters
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//performance notice: set nthread to be the number of your real cpu
|
//performance notice: set nthread to be the number of your real cpu
|
||||||
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
|
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
|
||||||
//param.put("nthread", "num_real_cpu");
|
//param.put("nthread", num_real_cpu);
|
||||||
|
|
||||||
//specify evaluate datasets and evaluate names
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "test"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//set round
|
//set round
|
||||||
int round = 2;
|
int round = 2;
|
||||||
|
|
||||||
//train a boost model
|
//train a boost model
|
||||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
|||||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this is an example of fit generalized linear model in xgboost
|
* this is an example of fit generalized linear model in xgboost
|
||||||
@ -39,8 +40,8 @@ public class GeneralizedLinearModel {
|
|||||||
//you can also set lambda_bias which is L2 regularizer on the bias term
|
//you can also set lambda_bias which is L2 regularizer on the bias term
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("alpha", "0.0001");
|
put("alpha", 0.0001);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
put("booster", "gblinear");
|
put("booster", "gblinear");
|
||||||
}
|
}
|
||||||
@ -52,13 +53,14 @@ public class GeneralizedLinearModel {
|
|||||||
//param.put("eta", "0.5");
|
//param.put("eta", "0.5");
|
||||||
|
|
||||||
|
|
||||||
//specify evaluate datasets and evaluate names
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "test"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
int round = 4;
|
int round = 4;
|
||||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||||
|
|
||||||
float[][] predicts = booster.predict(testMat);
|
float[][] predicts = booster.predict(testMat);
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import org.dmlc.xgboost4j.util.Params;
|
|||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
|
||||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict first ntree
|
* predict first ntree
|
||||||
@ -35,20 +36,21 @@ public class PredictFirstNtree {
|
|||||||
//specify parameters
|
//specify parameters
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//specify evaluate datasets and evaluate names
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "test"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
int round = 3;
|
int round = 3;
|
||||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||||
|
|
||||||
//predict use 1 tree
|
//predict use 1 tree
|
||||||
float[][] predicts1 = booster.predict(testMat, false, 1);
|
float[][] predicts1 = booster.predict(testMat, false, 1);
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import org.dmlc.xgboost4j.Booster;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.WatchList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict leaf indices
|
* predict leaf indices
|
||||||
@ -34,20 +35,21 @@ public class PredictLeafIndices {
|
|||||||
//specify parameters
|
//specify parameters
|
||||||
Params param = new Params() {
|
Params param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", "1.0");
|
put("eta", 1.0);
|
||||||
put("max_depth", "2");
|
put("max_depth", 2);
|
||||||
put("silent", "1");
|
put("silent", 1);
|
||||||
put("objective", "binary:logistic");
|
put("objective", "binary:logistic");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//specify evaluate datasets and evaluate names
|
//specify watchList
|
||||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
WatchList watchs = new WatchList();
|
||||||
String[] evalNames = new String[] {"train", "test"};
|
watchs.put("train", trainMat);
|
||||||
|
watchs.put("test", testMat);
|
||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
int round = 3;
|
int round = 3;
|
||||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||||
|
|
||||||
//predict using first 2 tree
|
//predict using first 2 tree
|
||||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||||
|
|||||||
@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory;
|
|||||||
|
|
||||||
import org.dmlc.xgboost4j.util.Initializer;
|
import org.dmlc.xgboost4j.util.Initializer;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
|
||||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||||
|
|
||||||
|
|
||||||
@ -85,7 +84,7 @@ public final class Booster {
|
|||||||
private void init(DMatrix[] dMatrixs) {
|
private void init(DMatrix[] dMatrixs) {
|
||||||
long[] handles = null;
|
long[] handles = null;
|
||||||
if(dMatrixs != null) {
|
if(dMatrixs != null) {
|
||||||
handles = TransferUtil.dMatrixs2handles(dMatrixs);
|
handles = dMatrixs2handles(dMatrixs);
|
||||||
}
|
}
|
||||||
handle = XgboostJNI.XGBoosterCreate(handles);
|
handle = XgboostJNI.XGBoosterCreate(handles);
|
||||||
}
|
}
|
||||||
@ -105,8 +104,8 @@ public final class Booster {
|
|||||||
*/
|
*/
|
||||||
public void setParams(Params params) {
|
public void setParams(Params params) {
|
||||||
if(params!=null) {
|
if(params!=null) {
|
||||||
for(Map.Entry<String, String> entry : params) {
|
for(Map.Entry<String, Object> entry : params) {
|
||||||
setParam(entry.getKey(), entry.getValue());
|
setParam(entry.getKey(), entry.getValue().toString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,7 +153,7 @@ public final class Booster {
|
|||||||
* @return eval information
|
* @return eval information
|
||||||
*/
|
*/
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
||||||
long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs);
|
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
||||||
return evalInfo;
|
return evalInfo;
|
||||||
}
|
}
|
||||||
@ -424,6 +423,19 @@ public final class Booster {
|
|||||||
return featureScore;
|
return featureScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* transfer DMatrix array to handle array (used for native functions)
|
||||||
|
* @param dmatrixs
|
||||||
|
* @return handle array for input dmatrixs
|
||||||
|
*/
|
||||||
|
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||||
|
long[] handles = new long[dmatrixs.length];
|
||||||
|
for(int i=0; i<dmatrixs.length; i++) {
|
||||||
|
handles[i] = dmatrixs[i].getHandle();
|
||||||
|
}
|
||||||
|
return handles;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finalize() {
|
protected void finalize() {
|
||||||
delete();
|
delete();
|
||||||
|
|||||||
@ -19,7 +19,6 @@ import java.io.IOException;
|
|||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
import org.dmlc.xgboost4j.util.Initializer;
|
import org.dmlc.xgboost4j.util.Initializer;
|
||||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
|
||||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -126,7 +125,7 @@ public class DMatrix {
|
|||||||
* @param baseMargin
|
* @param baseMargin
|
||||||
*/
|
*/
|
||||||
public void setBaseMargin(float[][] baseMargin) {
|
public void setBaseMargin(float[][] baseMargin) {
|
||||||
float[] flattenMargin = TransferUtil.flatten(baseMargin);
|
float[] flattenMargin = flatten(baseMargin);
|
||||||
setBaseMargin(flattenMargin);
|
setBaseMargin(flattenMargin);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,6 +202,24 @@ public class DMatrix {
|
|||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* flatten a mat to array
|
||||||
|
* @param mat
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private static float[] flatten(float[][] mat) {
|
||||||
|
int size = 0;
|
||||||
|
for (float[] array : mat) size += array.length;
|
||||||
|
float[] result = new float[size];
|
||||||
|
int pos = 0;
|
||||||
|
for (float[] ar : mat) {
|
||||||
|
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||||
|
pos += ar.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finalize() {
|
protected void finalize() {
|
||||||
delete();
|
delete();
|
||||||
|
|||||||
@ -28,7 +28,6 @@ public class CVPack {
|
|||||||
DMatrix dtrain;
|
DMatrix dtrain;
|
||||||
DMatrix dtest;
|
DMatrix dtest;
|
||||||
DMatrix[] dmats;
|
DMatrix[] dmats;
|
||||||
long[] dataArray;
|
|
||||||
String[] names;
|
String[] names;
|
||||||
Booster booster;
|
Booster booster;
|
||||||
|
|
||||||
@ -41,7 +40,6 @@ public class CVPack {
|
|||||||
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
||||||
dmats = new DMatrix[] {dtrain, dtest};
|
dmats = new DMatrix[] {dtrain, dtest};
|
||||||
booster = new Booster(params, dmats);
|
booster = new Booster(params, dmats);
|
||||||
dataArray = TransferUtil.dMatrixs2handles(dmats);
|
|
||||||
names = new String[] {"train", "test"};
|
names = new String[] {"train", "test"};
|
||||||
this.dtrain = dtrain;
|
this.dtrain = dtrain;
|
||||||
this.dtest = dtest;
|
this.dtest = dtest;
|
||||||
@ -70,7 +68,7 @@ public class CVPack {
|
|||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public String eval(int iter) {
|
public String eval(int iter) {
|
||||||
return booster.evalSet(dataArray, names, iter);
|
return booster.evalSet(dmats, names, iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -43,7 +43,7 @@ public class Initializer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* load native library, this method will first try to load library from java.library.path, then try to load from library in jar package.
|
* load native library, this method will first try to load library from java.library.path, then try to load library in jar package.
|
||||||
* @param libName
|
* @param libName
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -1,7 +1,17 @@
|
|||||||
/*
|
/*
|
||||||
* To change this license header, choose License Headers in Project Properties.
|
Copyright (c) 2014 by Contributors
|
||||||
* To change this template file, choose Tools | Templates
|
|
||||||
* and open the template in the editor.
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.util;
|
package org.dmlc.xgboost4j.util;
|
||||||
|
|
||||||
|
|||||||
@ -26,29 +26,29 @@ import java.util.AbstractMap;
|
|||||||
* a util class for handle params
|
* a util class for handle params
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class Params implements Iterable<Entry<String, String>>{
|
public class Params implements Iterable<Entry<String, Object>>{
|
||||||
List<Entry<String, String>> params = new ArrayList<>();
|
List<Entry<String, Object>> params = new ArrayList<>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* put param key-value pair
|
* put param key-value pair
|
||||||
* @param key
|
* @param key
|
||||||
* @param value
|
* @param value
|
||||||
*/
|
*/
|
||||||
public void put(String key, String value) {
|
public void put(String key, Object value) {
|
||||||
params.add(new AbstractMap.SimpleEntry<>(key, value));
|
params.add(new AbstractMap.SimpleEntry<>(key, value));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
String paramsInfo = "";
|
String paramsInfo = "";
|
||||||
for(Entry<String, String> param : params) {
|
for(Entry<String, Object> param : params) {
|
||||||
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
|
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
|
||||||
}
|
}
|
||||||
return paramsInfo;
|
return paramsInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Iterator<Entry<String, String>> iterator() {
|
public Iterator<Entry<String, Object>> iterator() {
|
||||||
return params.iterator();
|
return params.iterator();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import java.util.Collections;
|
|||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
import org.dmlc.xgboost4j.IEvaluation;
|
import org.dmlc.xgboost4j.IEvaluation;
|
||||||
@ -40,14 +41,26 @@ public class Trainer {
|
|||||||
* @param params Booster params.
|
* @param params Booster params.
|
||||||
* @param dtrain Data to be trained.
|
* @param dtrain Data to be trained.
|
||||||
* @param round Number of boosting iterations.
|
* @param round Number of boosting iterations.
|
||||||
* @param evalMats Data to be evaluated (may include dtrain)
|
* @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set.
|
||||||
* @param evalNames name of data (used for evaluation info)
|
|
||||||
* @param obj customized objective (set to null if not used)
|
* @param obj customized objective (set to null if not used)
|
||||||
* @param eval customized evaluation (set to null if not used)
|
* @param eval customized evaluation (set to null if not used)
|
||||||
* @return trained booster
|
* @return trained booster
|
||||||
*/
|
*/
|
||||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||||
DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) {
|
WatchList watchs, IObjective obj, IEvaluation eval) {
|
||||||
|
|
||||||
|
//collect eval matrixs
|
||||||
|
int len = watchs.size();
|
||||||
|
int i = 0;
|
||||||
|
String[] evalNames = new String[len];
|
||||||
|
DMatrix[] evalMats = new DMatrix[len];
|
||||||
|
|
||||||
|
for(Entry<String, DMatrix> evalEntry : watchs) {
|
||||||
|
evalNames[i] = evalEntry.getKey();
|
||||||
|
evalMats[i] = evalEntry.getValue();
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
//collect all data matrixs
|
//collect all data matrixs
|
||||||
DMatrix[] allMats;
|
DMatrix[] allMats;
|
||||||
if(evalMats!=null && evalMats.length>0) {
|
if(evalMats!=null && evalMats.length>0) {
|
||||||
@ -63,16 +76,6 @@ public class Trainer {
|
|||||||
//initialize booster
|
//initialize booster
|
||||||
Booster booster = new Booster(params, allMats);
|
Booster booster = new Booster(params, allMats);
|
||||||
|
|
||||||
//used for evaluation
|
|
||||||
long[] dataArray = null;
|
|
||||||
String[] names = null;
|
|
||||||
|
|
||||||
if(dataArray==null || names==null) {
|
|
||||||
//prepare data for evaluation
|
|
||||||
dataArray = TransferUtil.dMatrixs2handles(evalMats);
|
|
||||||
names = evalNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
//begin to train
|
//begin to train
|
||||||
for(int iter=0; iter<round; iter++) {
|
for(int iter=0; iter<round; iter++) {
|
||||||
if(obj != null) {
|
if(obj != null) {
|
||||||
@ -88,7 +91,7 @@ public class Trainer {
|
|||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
|
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
evalInfo = booster.evalSet(dataArray, names, iter);
|
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||||
}
|
}
|
||||||
logger.info(evalInfo);
|
logger.info(evalInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,55 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.dmlc.xgboost4j.util;
|
|
||||||
|
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @author hzx
|
|
||||||
*/
|
|
||||||
public class TransferUtil {
|
|
||||||
/**
|
|
||||||
* transfer DMatrix array to handle array (used for native functions)
|
|
||||||
* @param dmatrixs
|
|
||||||
* @return handle array for input dmatrixs
|
|
||||||
*/
|
|
||||||
public static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
|
||||||
long[] handles = new long[dmatrixs.length];
|
|
||||||
for(int i=0; i<dmatrixs.length; i++) {
|
|
||||||
handles[i] = dmatrixs[i].getHandle();
|
|
||||||
}
|
|
||||||
return handles;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* flatten a mat to array
|
|
||||||
* @param mat
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static float[] flatten(float[][] mat) {
|
|
||||||
int size = 0;
|
|
||||||
for (float[] array : mat) size += array.length;
|
|
||||||
float[] result = new float[size];
|
|
||||||
int pos = 0;
|
|
||||||
for (float[] ar : mat) {
|
|
||||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
|
||||||
pos += ar.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.dmlc.xgboost4j.util;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* class to handle evaluation dmatrix
|
||||||
|
* @author hzx
|
||||||
|
*/
|
||||||
|
public class WatchList implements Iterable<Entry<String, DMatrix> >{
|
||||||
|
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* put eval dmatrix and it's name
|
||||||
|
* @param name
|
||||||
|
* @param dmat
|
||||||
|
*/
|
||||||
|
public void put(String name, DMatrix dmat) {
|
||||||
|
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
|
||||||
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
return watchList.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterator<Entry<String, DMatrix>> iterator() {
|
||||||
|
return watchList.iterator();
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1 +0,0 @@
|
|||||||
please put native library in this package.
|
|
||||||
Loading…
x
Reference in New Issue
Block a user