rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>
This commit is contained in:
@@ -18,13 +18,19 @@ package org.dmlc.xgboost4j.demo;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.AbstractMap;
|
||||
import java.util.AbstractMap.SimpleEntry;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* a simple example of java wrapper for xgboost
|
||||
@@ -51,8 +57,32 @@ public class BasicWalkThrough {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
|
||||
//specify parameters
|
||||
Params param = new Params() {
|
||||
//note: any Iterable<Entry<String, Object>> object would be used as paramters
|
||||
//e.g.
|
||||
// Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
// {
|
||||
// put("eta", 1.0);
|
||||
// put("max_depth", 2);
|
||||
// put("silent", 1);
|
||||
// put("objective", "binary:logistic");
|
||||
// }
|
||||
// };
|
||||
// Iterable<Entry<String, Object>> param = paramMap.entrySet();
|
||||
|
||||
//or
|
||||
// List<Entry<String, Object>> param = new ArrayList<Entry<String, Object>>() {
|
||||
// {
|
||||
// add(new SimpleEntry<String, Object>("eta", 1.0));
|
||||
// add(new SimpleEntry<String, Object>("max_depth", 2.0));
|
||||
// add(new SimpleEntry<String, Object>("silent", 1));
|
||||
// add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
|
||||
// }
|
||||
// };
|
||||
|
||||
//we use a util class Params to handle parameters as example
|
||||
Iterable<Entry<String, Object>> param = new Params() {
|
||||
{
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
@@ -61,10 +91,21 @@ public class BasicWalkThrough {
|
||||
}
|
||||
};
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
|
||||
|
||||
//specify watchList to set evaluation dmats
|
||||
//note: any Iterable<Entry<String, DMatrix>> object would be used as watchList
|
||||
//e.g.
|
||||
//an entrySet of Map is good
|
||||
// Map<String, DMatrix> watchMap = new HashMap<>();
|
||||
// watchMap.put("train", trainMat);
|
||||
// watchMap.put("test", testMat);
|
||||
// Iterable<Entry<String, DMatrix>> watchs = watchMap.entrySet();
|
||||
|
||||
//we use a List of Entry<String, DMatrix> WatchList as example
|
||||
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new SimpleEntry<>("test", testMat));
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
@@ -110,9 +151,9 @@ public class BasicWalkThrough {
|
||||
trainMat2.setLabel(spData.labels);
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs2 = new WatchList();
|
||||
watchs2.put("train", trainMat2);
|
||||
watchs2.put("test", testMat);
|
||||
List<Entry<String, DMatrix>> watchs2 = new ArrayList<>();
|
||||
watchs2.add(new SimpleEntry<>("train", trainMat2));
|
||||
watchs2.add(new SimpleEntry<>("test", testMat2));
|
||||
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
|
||||
float[][] predicts3 = booster3.predict(testMat2);
|
||||
|
||||
|
||||
@@ -15,11 +15,14 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* example for start from a initial base prediction
|
||||
@@ -43,10 +46,10 @@ public class BoostFromPrediction {
|
||||
}
|
||||
};
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
//specify watchList
|
||||
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//train xgboost for 1 round
|
||||
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);
|
||||
|
||||
@@ -18,7 +18,7 @@ package org.dmlc.xgboost4j.demo;
|
||||
import java.io.IOException;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
|
||||
/**
|
||||
* an example of cross validation
|
||||
|
||||
@@ -15,15 +15,16 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IObjective;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
@@ -140,9 +141,9 @@ public class CustomObjective {
|
||||
int round = 2;
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//user define obj and eval
|
||||
IObjective obj = new LogRegObj();
|
||||
|
||||
@@ -15,11 +15,14 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* simple example for using external memory version
|
||||
@@ -48,9 +51,9 @@ public class ExternalMemory {
|
||||
//param.put("nthread", num_real_cpu);
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
@@ -15,12 +15,15 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
|
||||
/**
|
||||
* this is an example of fit generalized linear model in xgboost
|
||||
@@ -54,9 +57,9 @@ public class GeneralizedLinearModel {
|
||||
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//train a booster
|
||||
int round = 4;
|
||||
|
||||
@@ -15,13 +15,16 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
|
||||
/**
|
||||
* predict first ntree
|
||||
@@ -44,9 +47,9 @@ public class PredictFirstNtree {
|
||||
};
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
|
||||
@@ -15,12 +15,15 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.WatchList;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
|
||||
/**
|
||||
* predict leaf indices
|
||||
@@ -43,9 +46,9 @@ public class PredictLeafIndices {
|
||||
};
|
||||
|
||||
//specify watchList
|
||||
WatchList watchs = new WatchList();
|
||||
watchs.put("train", trainMat);
|
||||
watchs.put("test", testMat);
|
||||
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.AbstractMap;
|
||||
|
||||
|
||||
/**
|
||||
* a util class for handle params
|
||||
* @author hzx
|
||||
*/
|
||||
public class Params implements Iterable<Entry<String, Object>>{
|
||||
List<Entry<String, Object>> params = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put param key-value pair
|
||||
* @param key
|
||||
* @param value
|
||||
*/
|
||||
public void put(String key, Object value) {
|
||||
params.add(new AbstractMap.SimpleEntry<>(key, value));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
String paramsInfo = "";
|
||||
for(Entry<String, Object> param : params) {
|
||||
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
|
||||
}
|
||||
return paramsInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, Object>> iterator() {
|
||||
return params.iterator();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user