rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>
This commit is contained in:
@@ -25,11 +25,11 @@ import java.io.UnsupportedEncodingException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ public final class Booster {
|
||||
* @param params parameters
|
||||
* @param dMatrixs DMatrix array
|
||||
*/
|
||||
public Booster(Params params, DMatrix[] dMatrixs) {
|
||||
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
|
||||
init(dMatrixs);
|
||||
setParam("seed","0");
|
||||
setParams(params);
|
||||
@@ -71,7 +71,7 @@ public final class Booster {
|
||||
* @param params parameters
|
||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||
*/
|
||||
public Booster(Params params, String modelPath) {
|
||||
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
|
||||
handle = XgboostJNI.XGBoosterCreate(new long[] {});
|
||||
loadModel(modelPath);
|
||||
setParam("seed","0");
|
||||
@@ -102,7 +102,7 @@ public final class Booster {
|
||||
* set parameters
|
||||
* @param params parameters key-value map
|
||||
*/
|
||||
public void setParams(Params params) {
|
||||
public void setParams(Iterable<Entry<String, Object>> params) {
|
||||
if(params!=null) {
|
||||
for(Map.Entry<String, Object> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue().toString());
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.Map;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
@@ -37,7 +38,7 @@ public class CVPack {
|
||||
* @param dtest test data
|
||||
* @param params parameters
|
||||
*/
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
|
||||
dmats = new DMatrix[] {dtrain, dtest};
|
||||
booster = new Booster(params, dmats);
|
||||
names = new String[] {"train", "test"};
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.AbstractMap;
|
||||
|
||||
|
||||
/**
|
||||
* a util class for handle params
|
||||
* @author hzx
|
||||
*/
|
||||
public class Params implements Iterable<Entry<String, Object>>{
|
||||
List<Entry<String, Object>> params = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put param key-value pair
|
||||
* @param key
|
||||
* @param value
|
||||
*/
|
||||
public void put(String key, Object value) {
|
||||
params.add(new AbstractMap.SimpleEntry<>(key, value));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
String paramsInfo = "";
|
||||
for(Entry<String, Object> param : params) {
|
||||
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
|
||||
}
|
||||
return paramsInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, Object>> iterator() {
|
||||
return params.iterator();
|
||||
}
|
||||
}
|
||||
@@ -46,21 +46,23 @@ public class Trainer {
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return trained booster
|
||||
*/
|
||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||
WatchList watchs, IObjective obj, IEvaluation eval) {
|
||||
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
|
||||
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
|
||||
|
||||
//collect eval matrixs
|
||||
int len = watchs.size();
|
||||
int i = 0;
|
||||
String[] evalNames = new String[len];
|
||||
DMatrix[] evalMats = new DMatrix[len];
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
List<String> names = new ArrayList<>();
|
||||
List<DMatrix> mats = new ArrayList<>();
|
||||
|
||||
for(Entry<String, DMatrix> evalEntry : watchs) {
|
||||
evalNames[i] = evalEntry.getKey();
|
||||
evalMats[i] = evalEntry.getValue();
|
||||
i++;
|
||||
names.add(evalEntry.getKey());
|
||||
mats.add(evalEntry.getValue());
|
||||
}
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
if(evalMats!=null && evalMats.length>0) {
|
||||
@@ -110,7 +112,7 @@ public class Trainer {
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return evaluation history
|
||||
*/
|
||||
public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
||||
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
||||
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
||||
String[] evalHist = new String[round];
|
||||
String[] results = new String[cvPacks.length];
|
||||
@@ -147,7 +149,7 @@ public class Trainer {
|
||||
* @param evalMetrics Evaluation metrics
|
||||
* @return CV package array
|
||||
*/
|
||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) {
|
||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
|
||||
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
||||
int step = samples.size()/nfold;
|
||||
int[] testSlice = new int[step];
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
|
||||
/**
|
||||
* class to handle evaluation dmatrix
|
||||
* @author hzx
|
||||
*/
|
||||
public class WatchList implements Iterable<Entry<String, DMatrix> >{
|
||||
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put eval dmatrix and it's name
|
||||
* @param name
|
||||
* @param dmat
|
||||
*/
|
||||
public void put(String name, DMatrix dmat) {
|
||||
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return watchList.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, DMatrix>> iterator() {
|
||||
return watchList.iterator();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user