rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>

This commit is contained in:
yanqingmen 2015-06-10 23:34:52 -07:00
parent 8c5d3ac130
commit 4e8a1c6516
14 changed files with 136 additions and 129 deletions

View File

@ -73,14 +73,11 @@ dmat.setWeight(weights);
``` ```
#### Setting Parameters #### Setting Parameters
* A util class ```Params``` in xgboost4j is used to handle parameters. * in xgboost4j any ```Iterable<Entry<String, Object>>``` object could be used as parameters.
* To import ```Params``` :
* to set parameters, for non-multiple value params, you can simply use entrySet of an Map:
```java ```java
import org.dmlc.xgboost4j.util.Params; Map<String, Object> paramMap = new HashMap<>() {
```
* to set parameters :
```java
Params params = new Params() {
{ {
put("eta", 1.0); put("eta", 1.0);
put("max_depth", 2); put("max_depth", 2);
@ -89,18 +86,17 @@ Params params = new Params() {
put("eval_metric", "logloss"); put("eval_metric", "logloss");
} }
}; };
Iterable<Entry<String, Object>> params = paramMap.entrySet();
``` ```
* Multiple values with same param key is handled naturally in ```Params```, e.g. : * for the situation that multiple values with same param key, List<Entry<String, Object>> would be a good choice, e.g. :
```java ```java
Params params = new Params() { List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
{ {
put("eta", 1.0); add(new SimpleEntry<String, Object>("eta", 1.0));
put("max_depth", 2); add(new SimpleEntry<String, Object>("max_depth", 2.0));
put("silent", 1); add(new SimpleEntry<String, Object>("silent", 1));
put("objective", "binary:logistic"); add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
put("eval_metric", "logloss"); }
put("eval_metric", "error");
}
}; };
``` ```
@ -110,7 +106,6 @@ With parameters and data, you are able to train a booster model.
```java ```java
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
``` ```
* Training * Training
@ -118,9 +113,10 @@ import org.dmlc.xgboost4j.util.WatchList;
DMatrix trainMat = new DMatrix("train.svm.txt"); DMatrix trainMat = new DMatrix("train.svm.txt");
DMatrix validMat = new DMatrix("valid.svm.txt"); DMatrix validMat = new DMatrix("valid.svm.txt");
//specifiy a watchList to see the performance //specifiy a watchList to see the performance
WatchList watchs = new WatchList(); //any Iterable<Entry<String, DMatrix>> object could be used as watchList
watchs.put("train", trainMat); List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("test", testMat); watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
int round = 2; int round = 2;
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null); Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
``` ```

View File

@ -18,13 +18,19 @@ package org.dmlc.xgboost4j.demo;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.util.AbstractMap;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader; import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/** /**
* a simple example of java wrapper for xgboost * a simple example of java wrapper for xgboost
@ -51,8 +57,32 @@ public class BasicWalkThrough {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters //specify parameters
Params param = new Params() { //note: any Iterable<Entry<String, Object>> object would be used as paramters
//e.g.
// Map<String, Object> paramMap = new HashMap<String, Object>() {
// {
// put("eta", 1.0);
// put("max_depth", 2);
// put("silent", 1);
// put("objective", "binary:logistic");
// }
// };
// Iterable<Entry<String, Object>> param = paramMap.entrySet();
//or
// List<Entry<String, Object>> param = new ArrayList<Entry<String, Object>>() {
// {
// add(new SimpleEntry<String, Object>("eta", 1.0));
// add(new SimpleEntry<String, Object>("max_depth", 2.0));
// add(new SimpleEntry<String, Object>("silent", 1));
// add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
// }
// };
//we use a util class Params to handle parameters as example
Iterable<Entry<String, Object>> param = new Params() {
{ {
put("eta", 1.0); put("eta", 1.0);
put("max_depth", 2); put("max_depth", 2);
@ -61,10 +91,21 @@ public class BasicWalkThrough {
} }
}; };
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat); //specify watchList to set evaluation dmats
watchs.put("test", testMat); //note: any Iterable<Entry<String, DMatrix>> object would be used as watchList
//e.g.
//an entrySet of Map is good
// Map<String, DMatrix> watchMap = new HashMap<>();
// watchMap.put("train", trainMat);
// watchMap.put("test", testMat);
// Iterable<Entry<String, DMatrix>> watchs = watchMap.entrySet();
//we use a List of Entry<String, DMatrix> WatchList as example
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
//set round //set round
int round = 2; int round = 2;
@ -110,9 +151,9 @@ public class BasicWalkThrough {
trainMat2.setLabel(spData.labels); trainMat2.setLabel(spData.labels);
//specify watchList //specify watchList
WatchList watchs2 = new WatchList(); List<Entry<String, DMatrix>> watchs2 = new ArrayList<>();
watchs2.put("train", trainMat2); watchs2.add(new SimpleEntry<>("train", trainMat2));
watchs2.put("test", testMat); watchs2.add(new SimpleEntry<>("test", testMat2));
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null); Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
float[][] predicts3 = booster3.predict(testMat2); float[][] predicts3 = booster3.predict(testMat2);

View File

@ -15,11 +15,14 @@
*/ */
package org.dmlc.xgboost4j.demo; package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/** /**
* example for start from a initial base prediction * example for start from a initial base prediction
@ -43,10 +46,10 @@ public class BoostFromPrediction {
} }
}; };
//specify watchList //specify watchList
WatchList watchs = new WatchList(); List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("train", trainMat); watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.put("test", testMat); watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train xgboost for 1 round //train xgboost for 1 round
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null); Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);

View File

@ -18,7 +18,7 @@ package org.dmlc.xgboost4j.demo;
import java.io.IOException; import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
/** /**
* an example of cross validation * an example of cross validation

View File

@ -15,15 +15,16 @@
*/ */
package org.dmlc.xgboost4j.demo; package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective; import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/** /**
* an example user define objective and eval * an example user define objective and eval
@ -140,9 +141,9 @@ public class CustomObjective {
int round = 2; int round = 2;
//specify watchList //specify watchList
WatchList watchs = new WatchList(); List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("train", trainMat); watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.put("test", testMat); watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//user define obj and eval //user define obj and eval
IObjective obj = new LogRegObj(); IObjective obj = new LogRegObj();

View File

@ -15,11 +15,14 @@
*/ */
package org.dmlc.xgboost4j.demo; package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/** /**
* simple example for using external memory version * simple example for using external memory version
@ -48,9 +51,9 @@ public class ExternalMemory {
//param.put("nthread", num_real_cpu); //param.put("nthread", num_real_cpu);
//specify watchList //specify watchList
WatchList watchs = new WatchList(); List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("train", trainMat); watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.put("test", testMat); watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//set round //set round
int round = 2; int round = 2;

View File

@ -15,12 +15,15 @@
*/ */
package org.dmlc.xgboost4j.demo; package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/** /**
* this is an example of fit generalized linear model in xgboost * this is an example of fit generalized linear model in xgboost
@ -54,9 +57,9 @@ public class GeneralizedLinearModel {
//specify watchList //specify watchList
WatchList watchs = new WatchList(); List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("train", trainMat); watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.put("test", testMat); watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster //train a booster
int round = 4; int round = 4;

View File

@ -15,13 +15,16 @@
*/ */
package org.dmlc.xgboost4j.demo; package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.WatchList; import org.dmlc.xgboost4j.demo.util.Params;
/** /**
* predict first ntree * predict first ntree
@ -44,9 +47,9 @@ public class PredictFirstNtree {
}; };
//specify watchList //specify watchList
WatchList watchs = new WatchList(); List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("train", trainMat); watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.put("test", testMat); watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster //train a booster
int round = 3; int round = 3;

View File

@ -15,12 +15,15 @@
*/ */
package org.dmlc.xgboost4j.demo; package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList; import org.dmlc.xgboost4j.demo.util.Params;
/** /**
* predict leaf indices * predict leaf indices
@ -43,9 +46,9 @@ public class PredictLeafIndices {
}; };
//specify watchList //specify watchList
WatchList watchs = new WatchList(); List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.put("train", trainMat); watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.put("test", testMat); watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster //train a booster
int round = 3; int round = 3;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package org.dmlc.xgboost4j.util; package org.dmlc.xgboost4j.demo.util;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;

View File

@ -25,11 +25,11 @@ import java.io.UnsupportedEncodingException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer; import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.wrapper.XgboostJNI; import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@ -58,7 +58,7 @@ public final class Booster {
* @param params parameters * @param params parameters
* @param dMatrixs DMatrix array * @param dMatrixs DMatrix array
*/ */
public Booster(Params params, DMatrix[] dMatrixs) { public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
init(dMatrixs); init(dMatrixs);
setParam("seed","0"); setParam("seed","0");
setParams(params); setParams(params);
@ -71,7 +71,7 @@ public final class Booster {
* @param params parameters * @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel) * @param modelPath booster modelPath (model generated by booster.saveModel)
*/ */
public Booster(Params params, String modelPath) { public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
handle = XgboostJNI.XGBoosterCreate(new long[] {}); handle = XgboostJNI.XGBoosterCreate(new long[] {});
loadModel(modelPath); loadModel(modelPath);
setParam("seed","0"); setParam("seed","0");
@ -102,7 +102,7 @@ public final class Booster {
* set parameters * set parameters
* @param params parameters key-value map * @param params parameters key-value map
*/ */
public void setParams(Params params) { public void setParams(Iterable<Entry<String, Object>> params) {
if(params!=null) { if(params!=null) {
for(Map.Entry<String, Object> entry : params) { for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString()); setParam(entry.getKey(), entry.getValue().toString());

View File

@ -15,6 +15,7 @@
*/ */
package org.dmlc.xgboost4j.util; package org.dmlc.xgboost4j.util;
import java.util.Map;
import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
@ -37,7 +38,7 @@ public class CVPack {
* @param dtest test data * @param dtest test data
* @param params parameters * @param params parameters
*/ */
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) { public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
dmats = new DMatrix[] {dtrain, dtest}; dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats); booster = new Booster(params, dmats);
names = new String[] {"train", "test"}; names = new String[] {"train", "test"};

View File

@ -46,21 +46,23 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used) * @param eval customized evaluation (set to null if not used)
* @return trained booster * @return trained booster
*/ */
public static Booster train(Params params, DMatrix dtrain, int round, public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
WatchList watchs, IObjective obj, IEvaluation eval) { Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
//collect eval matrixs //collect eval matrixs
int len = watchs.size(); String[] evalNames;
int i = 0; DMatrix[] evalMats;
String[] evalNames = new String[len]; List<String> names = new ArrayList<>();
DMatrix[] evalMats = new DMatrix[len]; List<DMatrix> mats = new ArrayList<>();
for(Entry<String, DMatrix> evalEntry : watchs) { for(Entry<String, DMatrix> evalEntry : watchs) {
evalNames[i] = evalEntry.getKey(); names.add(evalEntry.getKey());
evalMats[i] = evalEntry.getValue(); mats.add(evalEntry.getValue());
i++;
} }
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
//collect all data matrixs //collect all data matrixs
DMatrix[] allMats; DMatrix[] allMats;
if(evalMats!=null && evalMats.length>0) { if(evalMats!=null && evalMats.length>0) {
@ -110,7 +112,7 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used) * @param eval customized evaluation (set to null if not used)
* @return evaluation history * @return evaluation history
*/ */
public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) { public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round]; String[] evalHist = new String[round];
String[] results = new String[cvPacks.length]; String[] results = new String[cvPacks.length];
@ -147,7 +149,7 @@ public class Trainer {
* @param evalMetrics Evaluation metrics * @param evalMetrics Evaluation metrics
* @return CV package array * @return CV package array
*/ */
public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) { public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum()); List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size()/nfold; int step = samples.size()/nfold;
int[] testSlice = new int[step]; int[] testSlice = new int[step];

View File

@ -1,49 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.DMatrix;
/**
* class to handle evaluation dmatrix
* @author hzx
*/
public class WatchList implements Iterable<Entry<String, DMatrix> >{
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
/**
* put eval dmatrix and it's name
* @param name
* @param dmat
*/
public void put(String name, DMatrix dmat) {
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
}
public int size() {
return watchList.size();
}
@Override
public Iterator<Entry<String, DMatrix>> iterator() {
return watchList.iterator();
}
}