rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>

This commit is contained in:
yanqingmen
2015-06-10 23:34:52 -07:00
parent 8c5d3ac130
commit 4e8a1c6516
14 changed files with 136 additions and 129 deletions

View File

@@ -73,14 +73,11 @@ dmat.setWeight(weights);
```
#### Setting Parameters
* A util class ```Params``` in xgboost4j is used to handle parameters.
* To import ```Params``` :
* in xgboost4j any ```Iterable<Entry<String, Object>>``` object could be used as parameters.
* to set parameters, for non-multiple value params, you can simply use entrySet of an Map:
```java
import org.dmlc.xgboost4j.util.Params;
```
* to set parameters :
```java
Params params = new Params() {
Map<String, Object> paramMap = new HashMap<>() {
{
put("eta", 1.0);
put("max_depth", 2);
@@ -89,18 +86,17 @@ Params params = new Params() {
put("eval_metric", "logloss");
}
};
Iterable<Entry<String, Object>> params = paramMap.entrySet();
```
* Multiple values with same param key is handled naturally in ```Params```, e.g. :
* for the situation that multiple values with same param key, List<Entry<String, Object>> would be a good choice, e.g. :
```java
Params params = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
put("eval_metric", "logloss");
put("eval_metric", "error");
}
List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
{
add(new SimpleEntry<String, Object>("eta", 1.0));
add(new SimpleEntry<String, Object>("max_depth", 2.0));
add(new SimpleEntry<String, Object>("silent", 1));
add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
}
};
```
@@ -110,7 +106,6 @@ With parameters and data, you are able to train a booster model.
```java
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
```
* Training
@@ -118,9 +113,10 @@ import org.dmlc.xgboost4j.util.WatchList;
DMatrix trainMat = new DMatrix("train.svm.txt");
DMatrix validMat = new DMatrix("valid.svm.txt");
//specifiy a watchList to see the performance
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
//any Iterable<Entry<String, DMatrix>> object could be used as watchList
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
int round = 2;
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
```