rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>

This commit is contained in:
yanqingmen
2015-06-10 23:34:52 -07:00
parent 8c5d3ac130
commit 4e8a1c6516
14 changed files with 136 additions and 129 deletions

View File

@@ -18,13 +18,19 @@ package org.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.AbstractMap;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* a simple example of java wrapper for xgboost
@@ -51,8 +57,32 @@ public class BasicWalkThrough {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
Params param = new Params() {
//note: any Iterable<Entry<String, Object>> object would be used as paramters
//e.g.
// Map<String, Object> paramMap = new HashMap<String, Object>() {
// {
// put("eta", 1.0);
// put("max_depth", 2);
// put("silent", 1);
// put("objective", "binary:logistic");
// }
// };
// Iterable<Entry<String, Object>> param = paramMap.entrySet();
//or
// List<Entry<String, Object>> param = new ArrayList<Entry<String, Object>>() {
// {
// add(new SimpleEntry<String, Object>("eta", 1.0));
// add(new SimpleEntry<String, Object>("max_depth", 2.0));
// add(new SimpleEntry<String, Object>("silent", 1));
// add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
// }
// };
//we use a util class Params to handle parameters as example
Iterable<Entry<String, Object>> param = new Params() {
{
put("eta", 1.0);
put("max_depth", 2);
@@ -61,10 +91,21 @@ public class BasicWalkThrough {
}
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
//specify watchList to set evaluation dmats
//note: any Iterable<Entry<String, DMatrix>> object would be used as watchList
//e.g.
//an entrySet of Map is good
// Map<String, DMatrix> watchMap = new HashMap<>();
// watchMap.put("train", trainMat);
// watchMap.put("test", testMat);
// Iterable<Entry<String, DMatrix>> watchs = watchMap.entrySet();
//we use a List of Entry<String, DMatrix> WatchList as example
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new SimpleEntry<>("train", trainMat));
watchs.add(new SimpleEntry<>("test", testMat));
//set round
int round = 2;
@@ -110,9 +151,9 @@ public class BasicWalkThrough {
trainMat2.setLabel(spData.labels);
//specify watchList
WatchList watchs2 = new WatchList();
watchs2.put("train", trainMat2);
watchs2.put("test", testMat);
List<Entry<String, DMatrix>> watchs2 = new ArrayList<>();
watchs2.add(new SimpleEntry<>("train", trainMat2));
watchs2.add(new SimpleEntry<>("test", testMat2));
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
float[][] predicts3 = booster3.predict(testMat2);

View File

@@ -15,11 +15,14 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* example for start from a initial base prediction
@@ -43,10 +46,10 @@ public class BoostFromPrediction {
}
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
//specify watchList
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train xgboost for 1 round
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);

View File

@@ -18,7 +18,7 @@ package org.dmlc.xgboost4j.demo;
import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
/**
* an example of cross validation

View File

@@ -15,15 +15,16 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* an example user define objective and eval
@@ -140,9 +141,9 @@ public class CustomObjective {
int round = 2;
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//user define obj and eval
IObjective obj = new LogRegObj();

View File

@@ -15,11 +15,14 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* simple example for using external memory version
@@ -48,9 +51,9 @@ public class ExternalMemory {
//param.put("nthread", num_real_cpu);
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//set round
int round = 2;

View File

@@ -15,12 +15,15 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
/**
* this is an example of fit generalized linear model in xgboost
@@ -54,9 +57,9 @@ public class GeneralizedLinearModel {
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 4;

View File

@@ -15,13 +15,16 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.WatchList;
import org.dmlc.xgboost4j.demo.util.Params;
/**
* predict first ntree
@@ -44,9 +47,9 @@ public class PredictFirstNtree {
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 3;

View File

@@ -15,12 +15,15 @@
*/
package org.dmlc.xgboost4j.demo;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.WatchList;
import org.dmlc.xgboost4j.demo.util.Params;
/**
* predict leaf indices
@@ -43,9 +46,9 @@ public class PredictLeafIndices {
};
//specify watchList
WatchList watchs = new WatchList();
watchs.put("train", trainMat);
watchs.put("test", testMat);
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
//train a booster
int round = 3;

View File

@@ -0,0 +1,54 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.AbstractMap;
/**
* a util class for handle params
* @author hzx
*/
public class Params implements Iterable<Entry<String, Object>>{
List<Entry<String, Object>> params = new ArrayList<>();
/**
* put param key-value pair
* @param key
* @param value
*/
public void put(String key, Object value) {
params.add(new AbstractMap.SimpleEntry<>(key, value));
}
@Override
public String toString(){
String paramsInfo = "";
for(Entry<String, Object> param : params) {
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
}
return paramsInfo;
}
@Override
public Iterator<Entry<String, Object>> iterator() {
return params.iterator();
}
}