make some fix
This commit is contained in:
parent
f91a098770
commit
c110111f52
@ -17,11 +17,11 @@ core of this wrapper is two classes:
|
||||
|
||||
## build native library
|
||||
|
||||
for windows: open the xgboost.sln in windows folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library:
|
||||
for windows: open the xgboost.sln in "../windows" folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library:
|
||||
* Select x64/win32 and Release in build
|
||||
* (if you have setted `JAVA_HOME` properly in windows environment variables, escape this step) right click on xgboostjavawrapper project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path.
|
||||
* rebuild all
|
||||
* move the dll "xgboostjavawrapper.dll" to "xgboost4j/src/main/resources/lib/"(you may need to create this folder if necessary.)
|
||||
* double click "create_wrap.bat" to set library to proper place
|
||||
|
||||
for linux:
|
||||
* make sure you have installed jdk and `JAVA_HOME` has been setted properly
|
||||
|
||||
20
java/create_wrap.bat
Normal file
20
java/create_wrap.bat
Normal file
@ -0,0 +1,20 @@
|
||||
echo "move native library"
|
||||
set libsource=..\windows\x64\Release\xgboostjavawrapper.dll
|
||||
|
||||
if not exist %libsource% (
|
||||
goto end
|
||||
)
|
||||
|
||||
set libfolder=xgboost4j\src\main\resources\lib
|
||||
set libpath=%libfolder%\xgboostjavawrapper.dll
|
||||
if not exist %libfolder% (mkdir %libfolder%)
|
||||
if exist %libpath% (del %libpath%)
|
||||
move %libsource% %libfolder%
|
||||
echo complete
|
||||
pause
|
||||
exit
|
||||
|
||||
:end
|
||||
echo "source library not found, please build it first from ..\windows\xgboost.sln"
|
||||
pause
|
||||
exit
|
||||
@ -6,7 +6,7 @@ echo "move native lib"
|
||||
|
||||
libPath="xgboost4j/src/main/resources/lib"
|
||||
if [ ! -d "$libPath" ]; then
|
||||
mkdir "$libPath"
|
||||
mkdir -p "$libPath"
|
||||
fi
|
||||
|
||||
rm -f xgboost4j/src/main/resources/lib/libxgboostjavawrapper.so
|
||||
|
||||
@ -82,9 +82,9 @@ import org.dmlc.xgboost4j.util.Params;
|
||||
```java
|
||||
Params params = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("eval_metric", "logloss");
|
||||
}
|
||||
@ -94,9 +94,9 @@ Params params = new Params() {
|
||||
```java
|
||||
Params params = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("eval_metric", "logloss");
|
||||
put("eval_metric", "error");
|
||||
@ -110,16 +110,19 @@ With parameters and data, you are able to train a booster model.
|
||||
```java
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
```
|
||||
|
||||
* Training
|
||||
```java
|
||||
DMatrix trainMat = new DMatrix("train.svm.txt");
|
||||
DMatrix validMat = new DMatrix("valid.svm.txt");
|
||||
DMatrix[] evalMats = new DMatrix[] {trainMat, validMat};
|
||||
String[] evalNames = new String[] {"train", "valid"};
|
||||
//specifiy a watchList to see the performance
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
int round = 2;
|
||||
Booster booster = Trainer.train(params, trainMat, round, evalMats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
|
||||
```
|
||||
|
||||
* Saving model
|
||||
@ -139,8 +142,8 @@ booster.dumpModel("modelInfo.txt", "featureMap.txt", false)
|
||||
```java
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("silent", "1");
|
||||
put("nthread", "6");
|
||||
put("silent", 1);
|
||||
put("nthread", 6);
|
||||
}
|
||||
};
|
||||
Booster booster = new Booster(param, "model.bin");
|
||||
|
||||
@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* a simple example of java wrapper for xgboost
|
||||
@ -53,22 +54,23 @@ public class BasicWalkThrough {
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||
|
||||
//predict
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
@ -107,8 +109,11 @@ public class BasicWalkThrough {
|
||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
|
||||
trainMat2.setLabel(spData.labels);
|
||||
|
||||
dmats = new DMatrix[] {trainMat2, testMat};
|
||||
Booster booster3 = Trainer.train(param, trainMat2, round, dmats, evalNames, null, null);
|
||||
//specify watchList
|
||||
WatchList watchs2 = new WatchList();
|
||||
watchs2.put("train", trainMat2);
|
||||
watchs2.put("test", testMat);
|
||||
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
|
||||
float[][] predicts3 = booster3.predict(testMat2);
|
||||
|
||||
//check predicts
|
||||
|
||||
@ -19,6 +19,7 @@ import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* example for start from a initial base prediction
|
||||
@ -35,19 +36,20 @@ public class BoostFromPrediction {
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//train xgboost for 1 round
|
||||
Booster booster = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);
|
||||
|
||||
float[][] trainPred = booster.predict(trainMat, true);
|
||||
float[][] testPred = booster.predict(testMat, true);
|
||||
@ -56,6 +58,6 @@ public class BoostFromPrediction {
|
||||
testMat.setBaseMargin(testPred);
|
||||
|
||||
System.out.println("result of running from initial prediction");
|
||||
Booster booster2 = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null);
|
||||
Booster booster2 = Trainer.train(param, trainMat, 1, watchs, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,12 +32,12 @@ public class CrossValidation {
|
||||
//set params
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "3");
|
||||
put("silent", "1");
|
||||
put("nthread", "6");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("nthread", 6);
|
||||
put("objective", "binary:logistic");
|
||||
put("gamma", "1.0");
|
||||
put("gamma", 1.0);
|
||||
put("eval_metric", "error");
|
||||
}
|
||||
};
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
@ -24,6 +23,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IObjective;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
@ -130,18 +130,19 @@ public class CustomObjective {
|
||||
//set params
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
}
|
||||
};
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//set evaluation data
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "eval"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//user define obj and eval
|
||||
IObjective obj = new LogRegObj();
|
||||
@ -149,6 +150,6 @@ public class CustomObjective {
|
||||
|
||||
//train a booster
|
||||
System.out.println("begin to train the booster model");
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, obj, eval);
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, obj, eval);
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* simple example for using external memory version
|
||||
@ -35,25 +36,26 @@ public class ExternalMemory {
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//performance notice: set nthread to be the number of your real cpu
|
||||
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
|
||||
//param.put("nthread", "num_real_cpu");
|
||||
//param.put("nthread", num_real_cpu);
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* this is an example of fit generalized linear model in xgboost
|
||||
@ -39,8 +40,8 @@ public class GeneralizedLinearModel {
|
||||
//you can also set lambda_bias which is L2 regularizer on the bias term
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("alpha", "0.0001");
|
||||
put("silent", "1");
|
||||
put("alpha", 0.0001);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
put("booster", "gblinear");
|
||||
}
|
||||
@ -52,13 +53,14 @@ public class GeneralizedLinearModel {
|
||||
//param.put("eta", "0.5");
|
||||
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//train a booster
|
||||
int round = 4;
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* predict first ntree
|
||||
@ -35,20 +36,21 @@ public class PredictFirstNtree {
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||
|
||||
//predict use 1 tree
|
||||
float[][] predicts1 = booster.predict(testMat, false, 1);
|
||||
|
||||
@ -20,6 +20,7 @@ import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* predict leaf indices
|
||||
@ -34,20 +35,21 @@ public class PredictLeafIndices {
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
{
|
||||
put("eta", "1.0");
|
||||
put("max_depth", "2");
|
||||
put("silent", "1");
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//specify evaluate datasets and evaluate names
|
||||
DMatrix[] dmats = new DMatrix[] {trainMat, testMat};
|
||||
String[] evalNames = new String[] {"train", "test"};
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null);
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||
|
||||
//predict using first 2 tree
|
||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||
|
||||
@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
|
||||
@ -85,7 +84,7 @@ public final class Booster {
|
||||
private void init(DMatrix[] dMatrixs) {
|
||||
long[] handles = null;
|
||||
if(dMatrixs != null) {
|
||||
handles = TransferUtil.dMatrixs2handles(dMatrixs);
|
||||
handles = dMatrixs2handles(dMatrixs);
|
||||
}
|
||||
handle = XgboostJNI.XGBoosterCreate(handles);
|
||||
}
|
||||
@ -105,8 +104,8 @@ public final class Booster {
|
||||
*/
|
||||
public void setParams(Params params) {
|
||||
if(params!=null) {
|
||||
for(Map.Entry<String, String> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue());
|
||||
for(Map.Entry<String, Object> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -154,7 +153,7 @@ public final class Booster {
|
||||
* @return eval information
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
||||
long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs);
|
||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
||||
return evalInfo;
|
||||
}
|
||||
@ -424,6 +423,19 @@ public final class Booster {
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for(int i=0; i<dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
delete();
|
||||
|
||||
@ -19,7 +19,6 @@ import java.io.IOException;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
/**
|
||||
@ -126,7 +125,7 @@ public class DMatrix {
|
||||
* @param baseMargin
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) {
|
||||
float[] flattenMargin = TransferUtil.flatten(baseMargin);
|
||||
float[] flattenMargin = flatten(baseMargin);
|
||||
setBaseMargin(flattenMargin);
|
||||
}
|
||||
|
||||
@ -203,6 +202,24 @@ public class DMatrix {
|
||||
return handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
private static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
delete();
|
||||
|
||||
@ -28,7 +28,6 @@ public class CVPack {
|
||||
DMatrix dtrain;
|
||||
DMatrix dtest;
|
||||
DMatrix[] dmats;
|
||||
long[] dataArray;
|
||||
String[] names;
|
||||
Booster booster;
|
||||
|
||||
@ -41,7 +40,6 @@ public class CVPack {
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
||||
dmats = new DMatrix[] {dtrain, dtest};
|
||||
booster = new Booster(params, dmats);
|
||||
dataArray = TransferUtil.dMatrixs2handles(dmats);
|
||||
names = new String[] {"train", "test"};
|
||||
this.dtrain = dtrain;
|
||||
this.dtest = dtest;
|
||||
@ -70,7 +68,7 @@ public class CVPack {
|
||||
* @return
|
||||
*/
|
||||
public String eval(int iter) {
|
||||
return booster.evalSet(dataArray, names, iter);
|
||||
return booster.evalSet(dmats, names, iter);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -43,7 +43,7 @@ public class Initializer {
|
||||
}
|
||||
|
||||
/**
|
||||
* load native library, this method will first try to load library from java.library.path, then try to load from library in jar package.
|
||||
* load native library, this method will first try to load library from java.library.path, then try to load library in jar package.
|
||||
* @param libName
|
||||
* @throws IOException
|
||||
*/
|
||||
|
||||
@ -1,7 +1,17 @@
|
||||
/*
|
||||
* To change this license header, choose License Headers in Project Properties.
|
||||
* To change this template file, choose Tools | Templates
|
||||
* and open the template in the editor.
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
|
||||
@ -26,29 +26,29 @@ import java.util.AbstractMap;
|
||||
* a util class for handle params
|
||||
* @author hzx
|
||||
*/
|
||||
public class Params implements Iterable<Entry<String, String>>{
|
||||
List<Entry<String, String>> params = new ArrayList<>();
|
||||
public class Params implements Iterable<Entry<String, Object>>{
|
||||
List<Entry<String, Object>> params = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put param key-value pair
|
||||
* @param key
|
||||
* @param value
|
||||
*/
|
||||
public void put(String key, String value) {
|
||||
public void put(String key, Object value) {
|
||||
params.add(new AbstractMap.SimpleEntry<>(key, value));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
String paramsInfo = "";
|
||||
for(Entry<String, String> param : params) {
|
||||
for(Entry<String, Object> param : params) {
|
||||
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
|
||||
}
|
||||
return paramsInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, String>> iterator() {
|
||||
public Iterator<Entry<String, Object>> iterator() {
|
||||
return params.iterator();
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
@ -40,14 +41,26 @@ public class Trainer {
|
||||
* @param params Booster params.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param evalMats Data to be evaluated (may include dtrain)
|
||||
* @param evalNames name of data (used for evaluation info)
|
||||
* @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set.
|
||||
* @param obj customized objective (set to null if not used)
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return trained booster
|
||||
*/
|
||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||
DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) {
|
||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||
WatchList watchs, IObjective obj, IEvaluation eval) {
|
||||
|
||||
//collect eval matrixs
|
||||
int len = watchs.size();
|
||||
int i = 0;
|
||||
String[] evalNames = new String[len];
|
||||
DMatrix[] evalMats = new DMatrix[len];
|
||||
|
||||
for(Entry<String, DMatrix> evalEntry : watchs) {
|
||||
evalNames[i] = evalEntry.getKey();
|
||||
evalMats[i] = evalEntry.getValue();
|
||||
i++;
|
||||
}
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
if(evalMats!=null && evalMats.length>0) {
|
||||
@ -63,16 +76,6 @@ public class Trainer {
|
||||
//initialize booster
|
||||
Booster booster = new Booster(params, allMats);
|
||||
|
||||
//used for evaluation
|
||||
long[] dataArray = null;
|
||||
String[] names = null;
|
||||
|
||||
if(dataArray==null || names==null) {
|
||||
//prepare data for evaluation
|
||||
dataArray = TransferUtil.dMatrixs2handles(evalMats);
|
||||
names = evalNames;
|
||||
}
|
||||
|
||||
//begin to train
|
||||
for(int iter=0; iter<round; iter++) {
|
||||
if(obj != null) {
|
||||
@ -88,7 +91,7 @@ public class Trainer {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
|
||||
}
|
||||
else {
|
||||
evalInfo = booster.evalSet(dataArray, names, iter);
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||
}
|
||||
logger.info(evalInfo);
|
||||
}
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class TransferUtil {
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
public static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for(int i=0; i<dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
public static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,49 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
|
||||
/**
|
||||
* class to handle evaluation dmatrix
|
||||
* @author hzx
|
||||
*/
|
||||
public class WatchList implements Iterable<Entry<String, DMatrix> >{
|
||||
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put eval dmatrix and it's name
|
||||
* @param name
|
||||
* @param dmat
|
||||
*/
|
||||
public void put(String name, DMatrix dmat) {
|
||||
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return watchList.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, DMatrix>> iterator() {
|
||||
return watchList.iterator();
|
||||
}
|
||||
}
|
||||
@ -1 +0,0 @@
|
||||
please put native library in this package.
|
||||
Loading…
x
Reference in New Issue
Block a user