rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>

This commit is contained in:
yanqingmen
2015-06-10 23:34:52 -07:00
parent 8c5d3ac130
commit 4e8a1c6516
14 changed files with 136 additions and 129 deletions

View File

@@ -25,11 +25,11 @@ import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@@ -58,7 +58,7 @@ public final class Booster {
* @param params parameters
* @param dMatrixs DMatrix array
*/
public Booster(Params params, DMatrix[] dMatrixs) {
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
init(dMatrixs);
setParam("seed","0");
setParams(params);
@@ -71,7 +71,7 @@ public final class Booster {
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
*/
public Booster(Params params, String modelPath) {
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
handle = XgboostJNI.XGBoosterCreate(new long[] {});
loadModel(modelPath);
setParam("seed","0");
@@ -102,7 +102,7 @@ public final class Booster {
* set parameters
* @param params parameters key-value map
*/
public void setParams(Params params) {
public void setParams(Iterable<Entry<String, Object>> params) {
if(params!=null) {
for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString());

View File

@@ -15,6 +15,7 @@
*/
package org.dmlc.xgboost4j.util;
import java.util.Map;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
@@ -37,7 +38,7 @@ public class CVPack {
* @param dtest test data
* @param params parameters
*/
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats);
names = new String[] {"train", "test"};

View File

@@ -1,54 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.AbstractMap;
/**
* a util class for handle params
* @author hzx
*/
public class Params implements Iterable<Entry<String, Object>>{
List<Entry<String, Object>> params = new ArrayList<>();
/**
* put param key-value pair
* @param key
* @param value
*/
public void put(String key, Object value) {
params.add(new AbstractMap.SimpleEntry<>(key, value));
}
@Override
public String toString(){
String paramsInfo = "";
for(Entry<String, Object> param : params) {
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
}
return paramsInfo;
}
@Override
public Iterator<Entry<String, Object>> iterator() {
return params.iterator();
}
}

View File

@@ -46,21 +46,23 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used)
* @return trained booster
*/
public static Booster train(Params params, DMatrix dtrain, int round,
WatchList watchs, IObjective obj, IEvaluation eval) {
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
//collect eval matrixs
int len = watchs.size();
int i = 0;
String[] evalNames = new String[len];
DMatrix[] evalMats = new DMatrix[len];
String[] evalNames;
DMatrix[] evalMats;
List<String> names = new ArrayList<>();
List<DMatrix> mats = new ArrayList<>();
for(Entry<String, DMatrix> evalEntry : watchs) {
evalNames[i] = evalEntry.getKey();
evalMats[i] = evalEntry.getValue();
i++;
names.add(evalEntry.getKey());
mats.add(evalEntry.getValue());
}
evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
//collect all data matrixs
DMatrix[] allMats;
if(evalMats!=null && evalMats.length>0) {
@@ -110,7 +112,7 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used)
* @return evaluation history
*/
public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round];
String[] results = new String[cvPacks.length];
@@ -147,7 +149,7 @@ public class Trainer {
* @param evalMetrics Evaluation metrics
* @return CV package array
*/
public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) {
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size()/nfold;
int[] testSlice = new int[step];

View File

@@ -1,49 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.DMatrix;
/**
* class to handle evaluation dmatrix
* @author hzx
*/
public class WatchList implements Iterable<Entry<String, DMatrix> >{
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
/**
* put eval dmatrix and it's name
* @param name
* @param dmat
*/
public void put(String name, DMatrix dmat) {
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
}
public int size() {
return watchList.size();
}
@Override
public Iterator<Entry<String, DMatrix>> iterator() {
return watchList.iterator();
}
}