rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>

This commit is contained in:
yanqingmen 2015-06-10 23:34:52 -07:00
parent 8c5d3ac130
commit 4e8a1c6516
14 changed files with 136 additions and 129 deletions

View File

@ -73,14 +73,11 @@ dmat.setWeight(weights);
```
#### Setting Parameters
* A util class ```Params``` in xgboost4j is used to handle parameters.
* To import ```Params``` :
* in xgboost4j any ```Iterable<Entry<String, Object>>``` object could be used as parameters.
* to set parameters, for non-multiple value params, you can simply use entrySet of an Map:
```java
import org.dmlc.xgboost4j.util.Params;
```
* to set parameters :
```java
Params params = new Params() {
Map<String, Object> paramMap = new HashMap<>() {
{
put("eta", 1.0);
put("max_depth", 2);
@ -89,18 +86,17 @@ Params params = new Params() {
put("eval_metric", "logloss");
}
};
Iterable<Entry<String, Object>> params = paramMap.entrySet();
```
* Multiple values with same param key is handled naturally in ```Params```, e.g. :
* for the situation that multiple values with same param key, List<Entry<String, Object>> would be a good choice, e.g. :
```java
Params params = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
put("eval_metric", "logloss");
put("eval_metric", "error");
}
List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
{
add(new SimpleEntry<String, Object>("eta", 1.0));
add(new SimpleEntry<String, Object>("max_depth", 2.0));
add(new SimpleEntry<String, Object>("silent", 1));
add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
}
};
```
@ -110,7 +106,6 @@ With parameters and data, you are able to train a booster model.
```java
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
```
* Training
@ -118,9 +113,10 @@ import org.dmlc.xgboost4j.util.WatchList;
DMatrix trainMat = new DMatrix("train.svm.txt");
DMatrix validMat = new DMatrix("valid.svm.txt");
//specifiy a watchList to see the performance
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
//any Iterable<Entry<String, DMatrix>> object could be used as watchList
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
int round = 2;
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
```

View File

@ -18,13 +18,19 @@ package org.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.AbstractMap;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* a simple example of java wrapper for xgboost
@ -51,8 +57,32 @@ public class BasicWalkThrough {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
Params param = new Params() {
//note: any Iterable<Entry<String, Object>> object would be used as paramters
//e.g.
// Map<String, Object> paramMap = new HashMap<String, Object>() {
// {
// put("eta", 1.0);
// put("max_depth", 2);
// put("silent", 1);
// put("objective", "binary:logistic");
// }
// };
// Iterable<Entry<String, Object>> param = paramMap.entrySet();
//or
// List<Entry<String, Object>> param = new ArrayList<Entry<String, Object>>() {
// {
// add(new SimpleEntry<String, Object>("eta", 1.0));
// add(new SimpleEntry<String, Object>("max_depth", 2.0));
// add(new SimpleEntry<String, Object>("silent", 1));
// add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
// }
// };
//we use a util class Params to handle parameters as example
Iterable<Entry<String, Object>> param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
@ -61,10 +91,21 @@ public class BasicWalkThrough {
}
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
//specify watchList to set evaluation dmats
//note: any Iterable<Entry<String, DMatrix>> object would be used as watchList
//e.g.
//an entrySet of Map is good
// Map<String, DMatrix> watchMap = new HashMap<>();
// watchMap.put("train", trainMat);
// watchMap.put("test", testMat);
// Iterable<Entry<String, DMatrix>> watchs = watchMap.entrySet();
//we use a List of Entry<String, DMatrix> WatchList as example
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
//set round
int round = 2;
@ -110,9 +151,9 @@ public class BasicWalkThrough {
trainMat2.setLabel(spData.labels);
//specify watchList
WatchList watchs2 = new WatchList();
watchs2.put("train", trainMat2);
watchs2.put("test", testMat);
List<Entry<String, DMatrix>> watchs2 = new ArrayList<>();
watchs2.add(new SimpleEntry<>("train", trainMat2));
watchs2.add(new SimpleEntry<>("test", testMat2));
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
float[][] predicts3 = booster3.predict(testMat2);

View File

@ -15,11 +15,14 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* example for start from a initial base prediction
@ -43,10 +46,10 @@ public class BoostFromPrediction {
}
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train xgboost for 1 round
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);

View File

@ -18,7 +18,7 @@ package org.dmlc.xgboost4j.demo;
import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
/**
* an example of cross validation

View File

@ -15,15 +15,16 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* an example user define objective and eval
@ -140,9 +141,9 @@ public class CustomObjective {
int round = 2;
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//user define obj and eval
IObjective obj = new LogRegObj();

View File

@ -15,11 +15,14 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* simple example for using external memory version
@ -48,9 +51,9 @@ public class ExternalMemory {
//param.put("nthread", num_real_cpu);
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//set round
int round = 2;

View File

@ -15,12 +15,15 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* this is an example of fit generalized linear model in xgboost
@ -54,9 +57,9 @@ public class GeneralizedLinearModel {
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 4;

View File

@ -15,13 +15,16 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.WatchList;
import org.dmlc.xgboost4j.demo.util.Params;
/**
* predict first ntree
@ -44,9 +47,9 @@ public class PredictFirstNtree {
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 3;

View File

@ -15,12 +15,15 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
import org.dmlc.xgboost4j.demo.util.Params;
/**
* predict leaf indices
@ -43,9 +46,9 @@ public class PredictLeafIndices {
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 3;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
package org.dmlc.xgboost4j.demo.util;
import java.util.ArrayList;
import java.util.Iterator;

View File

@ -25,11 +25,11 @@ import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@ -58,7 +58,7 @@ public final class Booster {
* @param params parameters
* @param dMatrixs DMatrix array
*/
public Booster(Params params, DMatrix[] dMatrixs) {
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
init(dMatrixs);
setParam("seed","0");
setParams(params);
@ -71,7 +71,7 @@ public final class Booster {
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
*/
public Booster(Params params, String modelPath) {
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
handle = XgboostJNI.XGBoosterCreate(new long[] {});
loadModel(modelPath);
setParam("seed","0");
@ -102,7 +102,7 @@ public final class Booster {
* set parameters
* @param params parameters key-value map
*/
public void setParams(Params params) {
public void setParams(Iterable<Entry<String, Object>> params) {
if(params!=null) {
for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString());

View File

@ -15,6 +15,7 @@
*/
package org.dmlc.xgboost4j.util;
import java.util.Map;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
@ -37,7 +38,7 @@ public class CVPack {
* @param dtest test data
* @param params parameters
*/
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats);
names = new String[] {"train", "test"};

View File

@ -46,21 +46,23 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used)
* @return trained booster
*/
public static Booster train(Params params, DMatrix dtrain, int round,
WatchList watchs, IObjective obj, IEvaluation eval) {
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
//collect eval matrixs
int len = watchs.size();
int i = 0;
String[] evalNames = new String[len];
DMatrix[] evalMats = new DMatrix[len];
String[] evalNames;
DMatrix[] evalMats;
List<String> names = new ArrayList<>();
List<DMatrix> mats = new ArrayList<>();
for(Entry<String, DMatrix> evalEntry : watchs) {
evalNames[i] = evalEntry.getKey();
evalMats[i] = evalEntry.getValue();
i++;
names.add(evalEntry.getKey());
mats.add(evalEntry.getValue());
}
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
//collect all data matrixs
DMatrix[] allMats;
if(evalMats!=null && evalMats.length>0) {
@ -110,7 +112,7 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used)
* @return evaluation history
*/
public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round];
String[] results = new String[cvPacks.length];
@ -147,7 +149,7 @@ public class Trainer {
* @param evalMetrics Evaluation metrics
* @return CV package array
*/
public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) {
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size()/nfold;
int[] testSlice = new int[step];

View File

@ -1,49 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.DMatrix;
/**
* class to handle evaluation dmatrix
* @author hzx
*/
public class WatchList implements Iterable<Entry<String, DMatrix> >{
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
/**
* put eval dmatrix and it's name
* @param name
* @param dmat
*/
public void put(String name, DMatrix dmat) {
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
}
public int size() {
return watchList.size();
}
@Override
public Iterator<Entry<String, DMatrix>> iterator() {
return watchList.iterator();
}
}