From c110111f521ae3b95096e4b73cf3233a9e21de1b Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Wed, 10 Jun 2015 20:09:49 -0700 Subject: [PATCH] make some fix --- java/README.md | 4 +- java/create_wrap.bat | 20 +++++++ java/create_wrap.sh | 2 +- java/doc/xgboost4j.md | 25 +++++---- .../dmlc/xgboost4j/demo/BasicWalkThrough.java | 23 +++++--- .../xgboost4j/demo/BoostFromPrediction.java | 18 +++--- .../dmlc/xgboost4j/demo/CrossValidation.java | 10 ++-- .../dmlc/xgboost4j/demo/CustomObjective.java | 17 +++--- .../dmlc/xgboost4j/demo/ExternalMemory.java | 18 +++--- .../demo/GeneralizedLinearModel.java | 14 +++-- .../xgboost4j/demo/PredictFirstNtree.java | 16 +++--- .../xgboost4j/demo/PredictLeafIndices.java | 16 +++--- .../main/java/org/dmlc/xgboost4j/Booster.java | 22 ++++++-- .../main/java/org/dmlc/xgboost4j/DMatrix.java | 21 ++++++- .../java/org/dmlc/xgboost4j/util/CVPack.java | 4 +- .../org/dmlc/xgboost4j/util/Initializer.java | 2 +- .../org/dmlc/xgboost4j/util/NativeUtils.java | 16 +++++- .../java/org/dmlc/xgboost4j/util/Params.java | 10 ++-- .../java/org/dmlc/xgboost4j/util/Trainer.java | 33 ++++++----- .../org/dmlc/xgboost4j/util/TransferUtil.java | 55 ------------------- .../org/dmlc/xgboost4j/util/WatchList.java | 49 +++++++++++++++++ .../src/main/resources/lib/README.md | 1 - 22 files changed, 234 insertions(+), 162 deletions(-) create mode 100644 java/create_wrap.bat delete mode 100644 java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/TransferUtil.java create mode 100644 java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/WatchList.java delete mode 100644 java/xgboost4j/src/main/resources/lib/README.md diff --git a/java/README.md b/java/README.md index 161d594d8..12cbb4582 100644 --- a/java/README.md +++ b/java/README.md @@ -17,11 +17,11 @@ core of this wrapper is two classes: ## build native library -for windows: open the xgboost.sln in windows folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library: +for windows: open the xgboost.sln in "../windows" folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library: * Select x64/win32 and Release in build * (if you have setted `JAVA_HOME` properly in windows environment variables, escape this step) right click on xgboostjavawrapper project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path. * rebuild all - * move the dll "xgboostjavawrapper.dll" to "xgboost4j/src/main/resources/lib/"(you may need to create this folder if necessary.) + * double click "create_wrap.bat" to set library to proper place for linux: * make sure you have installed jdk and `JAVA_HOME` has been setted properly diff --git a/java/create_wrap.bat b/java/create_wrap.bat new file mode 100644 index 000000000..e7f8603cd --- /dev/null +++ b/java/create_wrap.bat @@ -0,0 +1,20 @@ +echo "move native library" +set libsource=..\windows\x64\Release\xgboostjavawrapper.dll + +if not exist %libsource% ( +goto end +) + +set libfolder=xgboost4j\src\main\resources\lib +set libpath=%libfolder%\xgboostjavawrapper.dll +if not exist %libfolder% (mkdir %libfolder%) +if exist %libpath% (del %libpath%) +move %libsource% %libfolder% +echo complete +pause +exit + +:end + echo "source library not found, please build it first from ..\windows\xgboost.sln" + pause + exit \ No newline at end of file diff --git a/java/create_wrap.sh b/java/create_wrap.sh index 08b3f6792..d66e4dbd4 100755 --- a/java/create_wrap.sh +++ b/java/create_wrap.sh @@ -6,7 +6,7 @@ echo "move native lib" libPath="xgboost4j/src/main/resources/lib" if [ ! -d "$libPath" ]; then - mkdir "$libPath" + mkdir -p "$libPath" fi rm -f xgboost4j/src/main/resources/lib/libxgboostjavawrapper.so diff --git a/java/doc/xgboost4j.md b/java/doc/xgboost4j.md index f23ff509a..b383e9a04 100644 --- a/java/doc/xgboost4j.md +++ b/java/doc/xgboost4j.md @@ -82,9 +82,9 @@ import org.dmlc.xgboost4j.util.Params; ```java Params params = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); put("eval_metric", "logloss"); } @@ -94,9 +94,9 @@ Params params = new Params() { ```java Params params = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); put("eval_metric", "logloss"); put("eval_metric", "error"); @@ -110,16 +110,19 @@ With parameters and data, you are able to train a booster model. ```java import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; ``` * Training ```java DMatrix trainMat = new DMatrix("train.svm.txt"); DMatrix validMat = new DMatrix("valid.svm.txt"); -DMatrix[] evalMats = new DMatrix[] {trainMat, validMat}; -String[] evalNames = new String[] {"train", "valid"}; +//specifiy a watchList to see the performance +WatchList watchs = new WatchList(); +watchs.put("train", trainMat); +watchs.put("test", testMat); int round = 2; -Booster booster = Trainer.train(params, trainMat, round, evalMats, evalNames, null, null); +Booster booster = Trainer.train(params, trainMat, round, watchs, null, null); ``` * Saving model @@ -139,8 +142,8 @@ booster.dumpModel("modelInfo.txt", "featureMap.txt", false) ```java Params param = new Params() { { - put("silent", "1"); - put("nthread", "6"); + put("silent", 1); + put("nthread", 6); } }; Booster booster = new Booster(param, "model.bin"); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java index 778d05a4d..0a80ae314 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java @@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.DataLoader; import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; /** * a simple example of java wrapper for xgboost @@ -53,22 +54,23 @@ public class BasicWalkThrough { //specify parameters Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); } }; - //specify evaluate datasets and evaluate names - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "test"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //set round int round = 2; //train a boost model - Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null); + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); //predict float[][] predicts = booster.predict(testMat); @@ -107,8 +109,11 @@ public class BasicWalkThrough { DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR); trainMat2.setLabel(spData.labels); - dmats = new DMatrix[] {trainMat2, testMat}; - Booster booster3 = Trainer.train(param, trainMat2, round, dmats, evalNames, null, null); + //specify watchList + WatchList watchs2 = new WatchList(); + watchs2.put("train", trainMat2); + watchs2.put("test", testMat); + Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null); float[][] predicts3 = booster3.predict(testMat2); //check predicts diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java index ed029a6a1..54c8c1bee 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java @@ -19,6 +19,7 @@ import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; /** * example for start from a initial base prediction @@ -35,19 +36,20 @@ public class BoostFromPrediction { //specify parameters Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); } }; - //specify evaluate datasets and evaluate names - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "test"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //train xgboost for 1 round - Booster booster = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null); + Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null); float[][] trainPred = booster.predict(trainMat, true); float[][] testPred = booster.predict(testMat, true); @@ -56,6 +58,6 @@ public class BoostFromPrediction { testMat.setBaseMargin(testPred); System.out.println("result of running from initial prediction"); - Booster booster2 = Trainer.train(param, trainMat, 1, dmats, evalNames, null, null); + Booster booster2 = Trainer.train(param, trainMat, 1, watchs, null, null); } } diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java index 754ae072c..793ffb61d 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java @@ -32,12 +32,12 @@ public class CrossValidation { //set params Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "3"); - put("silent", "1"); - put("nthread", "6"); + put("eta", 1.0); + put("max_depth", 3); + put("silent", 1); + put("nthread", 6); put("objective", "binary:logistic"); - put("gamma", "1.0"); + put("gamma", 1.0); put("eval_metric", "error"); } }; diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java index d0caaf53f..ed8c9a9a9 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java @@ -16,7 +16,6 @@ package org.dmlc.xgboost4j.demo; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.IEvaluation; @@ -24,6 +23,7 @@ import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.IObjective; import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; /** * an example user define objective and eval @@ -130,18 +130,19 @@ public class CustomObjective { //set params Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); } }; //set round int round = 2; - //set evaluation data - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "eval"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //user define obj and eval IObjective obj = new LogRegObj(); @@ -149,6 +150,6 @@ public class CustomObjective { //train a booster System.out.println("begin to train the booster model"); - Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, obj, eval); + Booster booster = Trainer.train(param, trainMat, round, watchs, obj, eval); } } diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java index 2912d43eb..698245bf1 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java @@ -19,6 +19,7 @@ import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; /** * simple example for using external memory version @@ -35,25 +36,26 @@ public class ExternalMemory { //specify parameters Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); } }; //performance notice: set nthread to be the number of your real cpu //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4 - //param.put("nthread", "num_real_cpu"); + //param.put("nthread", num_real_cpu); - //specify evaluate datasets and evaluate names - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "test"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //set round int round = 2; //train a boost model - Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null); + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); } } diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java index 6bdc02ab5..a9b3ba8df 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java @@ -20,6 +20,7 @@ import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; /** * this is an example of fit generalized linear model in xgboost @@ -39,8 +40,8 @@ public class GeneralizedLinearModel { //you can also set lambda_bias which is L2 regularizer on the bias term Params param = new Params() { { - put("alpha", "0.0001"); - put("silent", "1"); + put("alpha", 0.0001); + put("silent", 1); put("objective", "binary:logistic"); put("booster", "gblinear"); } @@ -52,13 +53,14 @@ public class GeneralizedLinearModel { //param.put("eta", "0.5"); - //specify evaluate datasets and evaluate names - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "test"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //train a booster int round = 4; - Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null); + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); float[][] predicts = booster.predict(testMat); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java index 51604e8ec..bfcc04d06 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java @@ -21,6 +21,7 @@ import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.demo.util.CustomEval; +import org.dmlc.xgboost4j.util.WatchList; /** * predict first ntree @@ -35,20 +36,21 @@ public class PredictFirstNtree { //specify parameters Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); } }; - //specify evaluate datasets and evaluate names - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "test"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //train a booster int round = 3; - Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null); + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); //predict use 1 tree float[][] predicts1 = booster.predict(testMat, false, 1); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java index ced309b03..5f1c2e5ac 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java @@ -20,6 +20,7 @@ import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.WatchList; /** * predict leaf indices @@ -34,20 +35,21 @@ public class PredictLeafIndices { //specify parameters Params param = new Params() { { - put("eta", "1.0"); - put("max_depth", "2"); - put("silent", "1"); + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); put("objective", "binary:logistic"); } }; - //specify evaluate datasets and evaluate names - DMatrix[] dmats = new DMatrix[] {trainMat, testMat}; - String[] evalNames = new String[] {"train", "test"}; + //specify watchList + WatchList watchs = new WatchList(); + watchs.put("train", trainMat); + watchs.put("test", testMat); //train a booster int round = 3; - Booster booster = Trainer.train(param, trainMat, round, dmats, evalNames, null, null); + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); //predict using first 2 tree float[][] leafindex = booster.predict(testMat, 2, true); diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java index 91a2bd40b..3140b184e 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java @@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory; import org.dmlc.xgboost4j.util.Initializer; import org.dmlc.xgboost4j.util.Params; -import org.dmlc.xgboost4j.util.TransferUtil; import org.dmlc.xgboost4j.wrapper.XgboostJNI; @@ -85,7 +84,7 @@ public final class Booster { private void init(DMatrix[] dMatrixs) { long[] handles = null; if(dMatrixs != null) { - handles = TransferUtil.dMatrixs2handles(dMatrixs); + handles = dMatrixs2handles(dMatrixs); } handle = XgboostJNI.XGBoosterCreate(handles); } @@ -105,8 +104,8 @@ public final class Booster { */ public void setParams(Params params) { if(params!=null) { - for(Map.Entry entry : params) { - setParam(entry.getKey(), entry.getValue()); + for(Map.Entry entry : params) { + setParam(entry.getKey(), entry.getValue().toString()); } } } @@ -154,7 +153,7 @@ public final class Booster { * @return eval information */ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) { - long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs); + long[] handles = dMatrixs2handles(evalMatrixs); String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames); return evalInfo; } @@ -424,6 +423,19 @@ public final class Booster { return featureScore; } + /** + * transfer DMatrix array to handle array (used for native functions) + * @param dmatrixs + * @return handle array for input dmatrixs + */ + private static long[] dMatrixs2handles(DMatrix[] dmatrixs) { + long[] handles = new long[dmatrixs.length]; + for(int i=0; i>{ - List> params = new ArrayList<>(); +public class Params implements Iterable>{ + List> params = new ArrayList<>(); /** * put param key-value pair * @param key * @param value */ - public void put(String key, String value) { + public void put(String key, Object value) { params.add(new AbstractMap.SimpleEntry<>(key, value)); } @Override public String toString(){ String paramsInfo = ""; - for(Entry param : params) { + for(Entry param : params) { paramsInfo += param.getKey() + ":" + param.getValue() + "\n"; } return paramsInfo; } @Override - public Iterator> iterator() { + public Iterator> iterator() { return params.iterator(); } } diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java index 76f5f58bc..a53437477 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.dmlc.xgboost4j.IEvaluation; @@ -40,14 +41,26 @@ public class Trainer { * @param params Booster params. * @param dtrain Data to be trained. * @param round Number of boosting iterations. - * @param evalMats Data to be evaluated (may include dtrain) - * @param evalNames name of data (used for evaluation info) + * @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set. * @param obj customized objective (set to null if not used) * @param eval customized evaluation (set to null if not used) * @return trained booster */ - public static Booster train(Params params, DMatrix dtrain, int round, - DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) { + public static Booster train(Params params, DMatrix dtrain, int round, + WatchList watchs, IObjective obj, IEvaluation eval) { + + //collect eval matrixs + int len = watchs.size(); + int i = 0; + String[] evalNames = new String[len]; + DMatrix[] evalMats = new DMatrix[len]; + + for(Entry evalEntry : watchs) { + evalNames[i] = evalEntry.getKey(); + evalMats[i] = evalEntry.getValue(); + i++; + } + //collect all data matrixs DMatrix[] allMats; if(evalMats!=null && evalMats.length>0) { @@ -63,16 +76,6 @@ public class Trainer { //initialize booster Booster booster = new Booster(params, allMats); - //used for evaluation - long[] dataArray = null; - String[] names = null; - - if(dataArray==null || names==null) { - //prepare data for evaluation - dataArray = TransferUtil.dMatrixs2handles(evalMats); - names = evalNames; - } - //begin to train for(int iter=0; iter >{ + List> watchList = new ArrayList<>(); + + /** + * put eval dmatrix and it's name + * @param name + * @param dmat + */ + public void put(String name, DMatrix dmat) { + watchList.add(new AbstractMap.SimpleEntry<>(name, dmat)); + } + + public int size() { + return watchList.size(); + } + + @Override + public Iterator> iterator() { + return watchList.iterator(); + } +} diff --git a/java/xgboost4j/src/main/resources/lib/README.md b/java/xgboost4j/src/main/resources/lib/README.md deleted file mode 100644 index 9c4e25ae2..000000000 --- a/java/xgboost4j/src/main/resources/lib/README.md +++ /dev/null @@ -1 +0,0 @@ -please put native library in this package.