From 4e8a1c65168f314e8b34bde1fa01ec91e4bf87be Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Wed, 10 Jun 2015 23:34:52 -0700 Subject: [PATCH] rm WatchList class, take Iterable> as eval param, change Params to Iterable> --- java/doc/xgboost4j.md | 38 ++++++------ .../dmlc/xgboost4j/demo/BasicWalkThrough.java | 61 ++++++++++++++++--- .../xgboost4j/demo/BoostFromPrediction.java | 15 +++-- .../dmlc/xgboost4j/demo/CrossValidation.java | 2 +- .../dmlc/xgboost4j/demo/CustomObjective.java | 11 ++-- .../dmlc/xgboost4j/demo/ExternalMemory.java | 13 ++-- .../demo/GeneralizedLinearModel.java | 13 ++-- .../xgboost4j/demo/PredictFirstNtree.java | 13 ++-- .../xgboost4j/demo/PredictLeafIndices.java | 13 ++-- .../org/dmlc/xgboost4j/demo}/util/Params.java | 2 +- .../main/java/org/dmlc/xgboost4j/Booster.java | 8 +-- .../java/org/dmlc/xgboost4j/util/CVPack.java | 3 +- .../java/org/dmlc/xgboost4j/util/Trainer.java | 24 ++++---- .../org/dmlc/xgboost4j/util/WatchList.java | 49 --------------- 14 files changed, 136 insertions(+), 129 deletions(-) rename java/{xgboost4j/src/main/java/org/dmlc/xgboost4j => xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo}/util/Params.java (97%) delete mode 100644 java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/WatchList.java diff --git a/java/doc/xgboost4j.md b/java/doc/xgboost4j.md index b383e9a04..201b3cc05 100644 --- a/java/doc/xgboost4j.md +++ b/java/doc/xgboost4j.md @@ -73,14 +73,11 @@ dmat.setWeight(weights); ``` #### Setting Parameters -* A util class ```Params``` in xgboost4j is used to handle parameters. -* To import ```Params``` : +* in xgboost4j any ```Iterable>``` object could be used as parameters. + +* to set parameters, for non-multiple value params, you can simply use entrySet of an Map: ```java -import org.dmlc.xgboost4j.util.Params; -``` -* to set parameters : -```java -Params params = new Params() { +Map paramMap = new HashMap<>() { { put("eta", 1.0); put("max_depth", 2); @@ -89,18 +86,17 @@ Params params = new Params() { put("eval_metric", "logloss"); } }; +Iterable> params = paramMap.entrySet(); ``` -* Multiple values with same param key is handled naturally in ```Params```, e.g. : +* for the situation that multiple values with same param key, List> would be a good choice, e.g. : ```java -Params params = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - put("eval_metric", "logloss"); - put("eval_metric", "error"); - } +List> params = new ArrayList>() { + { + add(new SimpleEntry("eta", 1.0)); + add(new SimpleEntry("max_depth", 2.0)); + add(new SimpleEntry("silent", 1)); + add(new SimpleEntry("objective", "binary:logistic")); + } }; ``` @@ -110,7 +106,6 @@ With parameters and data, you are able to train a booster model. ```java import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; ``` * Training @@ -118,9 +113,10 @@ import org.dmlc.xgboost4j.util.WatchList; DMatrix trainMat = new DMatrix("train.svm.txt"); DMatrix validMat = new DMatrix("valid.svm.txt"); //specifiy a watchList to see the performance -WatchList watchs = new WatchList(); -watchs.put("train", trainMat); -watchs.put("test", testMat); +//any Iterable> object could be used as watchList +List> watchs = new ArrayList<>(); +watchs.add(new SimpleEntry<>("train", trainMat)); +watchs.add(new SimpleEntry<>("test", testMat)); int round = 2; Booster booster = Trainer.train(params, trainMat, round, watchs, null, null); ``` diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java index 0a80ae314..a0c7a3ae1 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java @@ -18,13 +18,19 @@ package org.dmlc.xgboost4j.demo; import java.io.File; import java.io.IOException; import java.io.UnsupportedEncodingException; +import java.util.AbstractMap; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.DataLoader; -import org.dmlc.xgboost4j.util.Params; +import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; /** * a simple example of java wrapper for xgboost @@ -51,8 +57,32 @@ public class BasicWalkThrough { DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + //specify parameters - Params param = new Params() { + //note: any Iterable> object would be used as paramters + //e.g. + // Map paramMap = new HashMap() { + // { + // put("eta", 1.0); + // put("max_depth", 2); + // put("silent", 1); + // put("objective", "binary:logistic"); + // } + // }; + // Iterable> param = paramMap.entrySet(); + + //or + // List> param = new ArrayList>() { + // { + // add(new SimpleEntry("eta", 1.0)); + // add(new SimpleEntry("max_depth", 2.0)); + // add(new SimpleEntry("silent", 1)); + // add(new SimpleEntry("objective", "binary:logistic")); + // } + // }; + + //we use a util class Params to handle parameters as example + Iterable> param = new Params() { { put("eta", 1.0); put("max_depth", 2); @@ -61,10 +91,21 @@ public class BasicWalkThrough { } }; - //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + + + //specify watchList to set evaluation dmats + //note: any Iterable> object would be used as watchList + //e.g. + //an entrySet of Map is good + // Map watchMap = new HashMap<>(); + // watchMap.put("train", trainMat); + // watchMap.put("test", testMat); + // Iterable> watchs = watchMap.entrySet(); + + //we use a List of Entry WatchList as example + List> watchs = new ArrayList<>(); + watchs.add(new SimpleEntry<>("train", trainMat)); + watchs.add(new SimpleEntry<>("test", testMat)); //set round int round = 2; @@ -110,9 +151,9 @@ public class BasicWalkThrough { trainMat2.setLabel(spData.labels); //specify watchList - WatchList watchs2 = new WatchList(); - watchs2.put("train", trainMat2); - watchs2.put("test", testMat); + List> watchs2 = new ArrayList<>(); + watchs2.add(new SimpleEntry<>("train", trainMat2)); + watchs2.add(new SimpleEntry<>("test", testMat2)); Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null); float[][] predicts3 = booster3.predict(testMat2); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java index 54c8c1bee..733c49503 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java @@ -15,11 +15,14 @@ */ package org.dmlc.xgboost4j.demo; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Params; +import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; /** * example for start from a initial base prediction @@ -43,10 +46,10 @@ public class BoostFromPrediction { } }; - //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); //train xgboost for 1 round Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java index 793ffb61d..0c470bf17 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java @@ -18,7 +18,7 @@ package org.dmlc.xgboost4j.demo; import java.io.IOException; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.Params; +import org.dmlc.xgboost4j.demo.util.Params; /** * an example of cross validation diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java index ed8c9a9a9..03c9c4b52 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java @@ -15,15 +15,16 @@ */ package org.dmlc.xgboost4j.demo; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.IObjective; -import org.dmlc.xgboost4j.util.Params; +import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; /** * an example user define objective and eval @@ -140,9 +141,9 @@ public class CustomObjective { int round = 2; //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); //user define obj and eval IObjective obj = new LogRegObj(); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java index 698245bf1..6ac687289 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java @@ -15,11 +15,14 @@ */ package org.dmlc.xgboost4j.demo; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Params; +import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; /** * simple example for using external memory version @@ -48,9 +51,9 @@ public class ExternalMemory { //param.put("nthread", num_real_cpu); //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); //set round int round = 2; diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java index a9b3ba8df..2a20edbff 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java @@ -15,12 +15,15 @@ */ package org.dmlc.xgboost4j.demo; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.CustomEval; -import org.dmlc.xgboost4j.util.Params; +import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; /** * this is an example of fit generalized linear model in xgboost @@ -54,9 +57,9 @@ public class GeneralizedLinearModel { //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); //train a booster int round = 4; diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java index bfcc04d06..8e3f3abfb 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java @@ -15,13 +15,16 @@ */ package org.dmlc.xgboost4j.demo; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.demo.util.CustomEval; -import org.dmlc.xgboost4j.util.WatchList; +import org.dmlc.xgboost4j.demo.util.Params; /** * predict first ntree @@ -44,9 +47,9 @@ public class PredictFirstNtree { }; //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); //train a booster int round = 3; diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java index 5f1c2e5ac..697f40379 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java @@ -15,12 +15,15 @@ */ package org.dmlc.xgboost4j.demo; +import java.util.AbstractMap; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.Map; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.WatchList; +import org.dmlc.xgboost4j.demo.util.Params; /** * predict leaf indices @@ -43,9 +46,9 @@ public class PredictLeafIndices { }; //specify watchList - WatchList watchs = new WatchList(); - watchs.put("train", trainMat); - watchs.put("test", testMat); + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); //train a booster int round = 3; diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Params.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/Params.java similarity index 97% rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Params.java rename to java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/Params.java index 582620174..0f4c5c738 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Params.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/Params.java @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package org.dmlc.xgboost4j.util; +package org.dmlc.xgboost4j.demo.util; import java.util.ArrayList; import java.util.Iterator; diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java index 3140b184e..c5d8b1006 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java @@ -25,11 +25,11 @@ import java.io.UnsupportedEncodingException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.dmlc.xgboost4j.util.Initializer; -import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.wrapper.XgboostJNI; @@ -58,7 +58,7 @@ public final class Booster { * @param params parameters * @param dMatrixs DMatrix array */ - public Booster(Params params, DMatrix[] dMatrixs) { + public Booster(Iterable> params, DMatrix[] dMatrixs) { init(dMatrixs); setParam("seed","0"); setParams(params); @@ -71,7 +71,7 @@ public final class Booster { * @param params parameters * @param modelPath booster modelPath (model generated by booster.saveModel) */ - public Booster(Params params, String modelPath) { + public Booster(Iterable> params, String modelPath) { handle = XgboostJNI.XGBoosterCreate(new long[] {}); loadModel(modelPath); setParam("seed","0"); @@ -102,7 +102,7 @@ public final class Booster { * set parameters * @param params parameters key-value map */ - public void setParams(Params params) { + public void setParams(Iterable> params) { if(params!=null) { for(Map.Entry entry : params) { setParam(entry.getKey(), entry.getValue().toString()); diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java index a0d145636..3e67dc669 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java @@ -15,6 +15,7 @@ */ package org.dmlc.xgboost4j.util; +import java.util.Map; import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; @@ -37,7 +38,7 @@ public class CVPack { * @param dtest test data * @param params parameters */ - public CVPack(DMatrix dtrain, DMatrix dtest, Params params) { + public CVPack(DMatrix dtrain, DMatrix dtest, Iterable> params) { dmats = new DMatrix[] {dtrain, dtest}; booster = new Booster(params, dmats); names = new String[] {"train", "test"}; diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java index a53437477..8a336b1a8 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java @@ -46,21 +46,23 @@ public class Trainer { * @param eval customized evaluation (set to null if not used) * @return trained booster */ - public static Booster train(Params params, DMatrix dtrain, int round, - WatchList watchs, IObjective obj, IEvaluation eval) { + public static Booster train(Iterable> params, DMatrix dtrain, int round, + Iterable> watchs, IObjective obj, IEvaluation eval) { //collect eval matrixs - int len = watchs.size(); - int i = 0; - String[] evalNames = new String[len]; - DMatrix[] evalMats = new DMatrix[len]; + String[] evalNames; + DMatrix[] evalMats; + List names = new ArrayList<>(); + List mats = new ArrayList<>(); for(Entry evalEntry : watchs) { - evalNames[i] = evalEntry.getKey(); - evalMats[i] = evalEntry.getValue(); - i++; + names.add(evalEntry.getKey()); + mats.add(evalEntry.getValue()); } + evalNames = names.toArray(new String[names.size()]); + evalMats = mats.toArray(new DMatrix[mats.size()]); + //collect all data matrixs DMatrix[] allMats; if(evalMats!=null && evalMats.length>0) { @@ -110,7 +112,7 @@ public class Trainer { * @param eval customized evaluation (set to null if not used) * @return evaluation history */ - public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) { + public static String[] crossValiation(Iterable> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) { CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); String[] evalHist = new String[round]; String[] results = new String[cvPacks.length]; @@ -147,7 +149,7 @@ public class Trainer { * @param evalMetrics Evaluation metrics * @return CV package array */ - public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) { + public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable> params, String[] evalMetrics) { List samples = genRandPermutationNums(0, (int) data.rowNum()); int step = samples.size()/nfold; int[] testSlice = new int[step]; diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/WatchList.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/WatchList.java deleted file mode 100644 index a08b96208..000000000 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/WatchList.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.util; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map.Entry; -import org.dmlc.xgboost4j.DMatrix; - -/** - * class to handle evaluation dmatrix - * @author hzx - */ -public class WatchList implements Iterable >{ - List> watchList = new ArrayList<>(); - - /** - * put eval dmatrix and it's name - * @param name - * @param dmat - */ - public void put(String name, DMatrix dmat) { - watchList.add(new AbstractMap.SimpleEntry<>(name, dmat)); - } - - public int size() { - return watchList.size(); - } - - @Override - public Iterator> iterator() { - return watchList.iterator(); - } -}