rm WatchList class, take Iterable<Entry<String, DMatrix>> as eval param, change Params to Iterable<Entry<String, Object>>
This commit is contained in:
parent
8c5d3ac130
commit
4e8a1c6516
@ -73,14 +73,11 @@ dmat.setWeight(weights);
|
|||||||
```
|
```
|
||||||
|
|
||||||
#### Setting Parameters
|
#### Setting Parameters
|
||||||
* A util class ```Params``` in xgboost4j is used to handle parameters.
|
* in xgboost4j any ```Iterable<Entry<String, Object>>``` object could be used as parameters.
|
||||||
* To import ```Params``` :
|
|
||||||
|
* to set parameters, for non-multiple value params, you can simply use entrySet of an Map:
|
||||||
```java
|
```java
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
Map<String, Object> paramMap = new HashMap<>() {
|
||||||
```
|
|
||||||
* to set parameters :
|
|
||||||
```java
|
|
||||||
Params params = new Params() {
|
|
||||||
{
|
{
|
||||||
put("eta", 1.0);
|
put("eta", 1.0);
|
||||||
put("max_depth", 2);
|
put("max_depth", 2);
|
||||||
@ -89,18 +86,17 @@ Params params = new Params() {
|
|||||||
put("eval_metric", "logloss");
|
put("eval_metric", "logloss");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Iterable<Entry<String, Object>> params = paramMap.entrySet();
|
||||||
```
|
```
|
||||||
* Multiple values with same param key is handled naturally in ```Params```, e.g. :
|
* for the situation that multiple values with same param key, List<Entry<String, Object>> would be a good choice, e.g. :
|
||||||
```java
|
```java
|
||||||
Params params = new Params() {
|
List<Entry<String, Object>> params = new ArrayList<Entry<String, Object>>() {
|
||||||
{
|
{
|
||||||
put("eta", 1.0);
|
add(new SimpleEntry<String, Object>("eta", 1.0));
|
||||||
put("max_depth", 2);
|
add(new SimpleEntry<String, Object>("max_depth", 2.0));
|
||||||
put("silent", 1);
|
add(new SimpleEntry<String, Object>("silent", 1));
|
||||||
put("objective", "binary:logistic");
|
add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
|
||||||
put("eval_metric", "logloss");
|
}
|
||||||
put("eval_metric", "error");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -110,7 +106,6 @@ With parameters and data, you are able to train a booster model.
|
|||||||
```java
|
```java
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
|
||||||
```
|
```
|
||||||
|
|
||||||
* Training
|
* Training
|
||||||
@ -118,9 +113,10 @@ import org.dmlc.xgboost4j.util.WatchList;
|
|||||||
DMatrix trainMat = new DMatrix("train.svm.txt");
|
DMatrix trainMat = new DMatrix("train.svm.txt");
|
||||||
DMatrix validMat = new DMatrix("valid.svm.txt");
|
DMatrix validMat = new DMatrix("valid.svm.txt");
|
||||||
//specifiy a watchList to see the performance
|
//specifiy a watchList to see the performance
|
||||||
WatchList watchs = new WatchList();
|
//any Iterable<Entry<String, DMatrix>> object could be used as watchList
|
||||||
watchs.put("train", trainMat);
|
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("test", testMat);
|
watchs.add(new SimpleEntry<>("train", trainMat));
|
||||||
|
watchs.add(new SimpleEntry<>("test", testMat));
|
||||||
int round = 2;
|
int round = 2;
|
||||||
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
|
Booster booster = Trainer.train(params, trainMat, round, watchs, null, null);
|
||||||
```
|
```
|
||||||
|
|||||||
@ -18,13 +18,19 @@ package org.dmlc.xgboost4j.demo;
|
|||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.UnsupportedEncodingException;
|
import java.io.UnsupportedEncodingException;
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.AbstractMap.SimpleEntry;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* a simple example of java wrapper for xgboost
|
* a simple example of java wrapper for xgboost
|
||||||
@ -51,8 +57,32 @@ public class BasicWalkThrough {
|
|||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
|
||||||
//specify parameters
|
//specify parameters
|
||||||
Params param = new Params() {
|
//note: any Iterable<Entry<String, Object>> object would be used as paramters
|
||||||
|
//e.g.
|
||||||
|
// Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
// {
|
||||||
|
// put("eta", 1.0);
|
||||||
|
// put("max_depth", 2);
|
||||||
|
// put("silent", 1);
|
||||||
|
// put("objective", "binary:logistic");
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
// Iterable<Entry<String, Object>> param = paramMap.entrySet();
|
||||||
|
|
||||||
|
//or
|
||||||
|
// List<Entry<String, Object>> param = new ArrayList<Entry<String, Object>>() {
|
||||||
|
// {
|
||||||
|
// add(new SimpleEntry<String, Object>("eta", 1.0));
|
||||||
|
// add(new SimpleEntry<String, Object>("max_depth", 2.0));
|
||||||
|
// add(new SimpleEntry<String, Object>("silent", 1));
|
||||||
|
// add(new SimpleEntry<String, Object>("objective", "binary:logistic"));
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
|
//we use a util class Params to handle parameters as example
|
||||||
|
Iterable<Entry<String, Object>> param = new Params() {
|
||||||
{
|
{
|
||||||
put("eta", 1.0);
|
put("eta", 1.0);
|
||||||
put("max_depth", 2);
|
put("max_depth", 2);
|
||||||
@ -61,10 +91,21 @@ public class BasicWalkThrough {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//specify watchList
|
|
||||||
WatchList watchs = new WatchList();
|
|
||||||
watchs.put("train", trainMat);
|
//specify watchList to set evaluation dmats
|
||||||
watchs.put("test", testMat);
|
//note: any Iterable<Entry<String, DMatrix>> object would be used as watchList
|
||||||
|
//e.g.
|
||||||
|
//an entrySet of Map is good
|
||||||
|
// Map<String, DMatrix> watchMap = new HashMap<>();
|
||||||
|
// watchMap.put("train", trainMat);
|
||||||
|
// watchMap.put("test", testMat);
|
||||||
|
// Iterable<Entry<String, DMatrix>> watchs = watchMap.entrySet();
|
||||||
|
|
||||||
|
//we use a List of Entry<String, DMatrix> WatchList as example
|
||||||
|
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
|
watchs.add(new SimpleEntry<>("train", trainMat));
|
||||||
|
watchs.add(new SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//set round
|
//set round
|
||||||
int round = 2;
|
int round = 2;
|
||||||
@ -110,9 +151,9 @@ public class BasicWalkThrough {
|
|||||||
trainMat2.setLabel(spData.labels);
|
trainMat2.setLabel(spData.labels);
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs2 = new WatchList();
|
List<Entry<String, DMatrix>> watchs2 = new ArrayList<>();
|
||||||
watchs2.put("train", trainMat2);
|
watchs2.add(new SimpleEntry<>("train", trainMat2));
|
||||||
watchs2.put("test", testMat);
|
watchs2.add(new SimpleEntry<>("test", testMat2));
|
||||||
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
|
Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null);
|
||||||
float[][] predicts3 = booster3.predict(testMat2);
|
float[][] predicts3 = booster3.predict(testMat2);
|
||||||
|
|
||||||
|
|||||||
@ -15,11 +15,14 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* example for start from a initial base prediction
|
* example for start from a initial base prediction
|
||||||
@ -43,10 +46,10 @@ public class BoostFromPrediction {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs = new WatchList();
|
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("train", trainMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
watchs.put("test", testMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//train xgboost for 1 round
|
//train xgboost for 1 round
|
||||||
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);
|
Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null);
|
||||||
|
|||||||
@ -18,7 +18,7 @@ package org.dmlc.xgboost4j.demo;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* an example of cross validation
|
* an example of cross validation
|
||||||
|
|||||||
@ -15,15 +15,16 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.IEvaluation;
|
import org.dmlc.xgboost4j.IEvaluation;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.IObjective;
|
import org.dmlc.xgboost4j.IObjective;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* an example user define objective and eval
|
* an example user define objective and eval
|
||||||
@ -140,9 +141,9 @@ public class CustomObjective {
|
|||||||
int round = 2;
|
int round = 2;
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs = new WatchList();
|
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("train", trainMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
watchs.put("test", testMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//user define obj and eval
|
//user define obj and eval
|
||||||
IObjective obj = new LogRegObj();
|
IObjective obj = new LogRegObj();
|
||||||
|
|||||||
@ -15,11 +15,14 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* simple example for using external memory version
|
* simple example for using external memory version
|
||||||
@ -48,9 +51,9 @@ public class ExternalMemory {
|
|||||||
//param.put("nthread", num_real_cpu);
|
//param.put("nthread", num_real_cpu);
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs = new WatchList();
|
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("train", trainMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
watchs.put("test", testMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//set round
|
//set round
|
||||||
int round = 2;
|
int round = 2;
|
||||||
|
|||||||
@ -15,12 +15,15 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this is an example of fit generalized linear model in xgboost
|
* this is an example of fit generalized linear model in xgboost
|
||||||
@ -54,9 +57,9 @@ public class GeneralizedLinearModel {
|
|||||||
|
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs = new WatchList();
|
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("train", trainMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
watchs.put("test", testMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
int round = 4;
|
int round = 4;
|
||||||
|
|||||||
@ -15,13 +15,16 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
|
||||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict first ntree
|
* predict first ntree
|
||||||
@ -44,9 +47,9 @@ public class PredictFirstNtree {
|
|||||||
};
|
};
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs = new WatchList();
|
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("train", trainMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
watchs.put("test", testMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
int round = 3;
|
int round = 3;
|
||||||
|
|||||||
@ -15,12 +15,15 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo;
|
package org.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.util.WatchList;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict leaf indices
|
* predict leaf indices
|
||||||
@ -43,9 +46,9 @@ public class PredictLeafIndices {
|
|||||||
};
|
};
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
WatchList watchs = new WatchList();
|
List<Map.Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
watchs.put("train", trainMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
watchs.put("test", testMat);
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
//train a booster
|
//train a booster
|
||||||
int round = 3;
|
int round = 3;
|
||||||
|
|||||||
@ -13,7 +13,7 @@
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.util;
|
package org.dmlc.xgboost4j.demo.util;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
@ -25,11 +25,11 @@ import java.io.UnsupportedEncodingException;
|
|||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
import org.dmlc.xgboost4j.util.Initializer;
|
import org.dmlc.xgboost4j.util.Initializer;
|
||||||
import org.dmlc.xgboost4j.util.Params;
|
|
||||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ public final class Booster {
|
|||||||
* @param params parameters
|
* @param params parameters
|
||||||
* @param dMatrixs DMatrix array
|
* @param dMatrixs DMatrix array
|
||||||
*/
|
*/
|
||||||
public Booster(Params params, DMatrix[] dMatrixs) {
|
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
|
||||||
init(dMatrixs);
|
init(dMatrixs);
|
||||||
setParam("seed","0");
|
setParam("seed","0");
|
||||||
setParams(params);
|
setParams(params);
|
||||||
@ -71,7 +71,7 @@ public final class Booster {
|
|||||||
* @param params parameters
|
* @param params parameters
|
||||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||||
*/
|
*/
|
||||||
public Booster(Params params, String modelPath) {
|
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
|
||||||
handle = XgboostJNI.XGBoosterCreate(new long[] {});
|
handle = XgboostJNI.XGBoosterCreate(new long[] {});
|
||||||
loadModel(modelPath);
|
loadModel(modelPath);
|
||||||
setParam("seed","0");
|
setParam("seed","0");
|
||||||
@ -102,7 +102,7 @@ public final class Booster {
|
|||||||
* set parameters
|
* set parameters
|
||||||
* @param params parameters key-value map
|
* @param params parameters key-value map
|
||||||
*/
|
*/
|
||||||
public void setParams(Params params) {
|
public void setParams(Iterable<Entry<String, Object>> params) {
|
||||||
if(params!=null) {
|
if(params!=null) {
|
||||||
for(Map.Entry<String, Object> entry : params) {
|
for(Map.Entry<String, Object> entry : params) {
|
||||||
setParam(entry.getKey(), entry.getValue().toString());
|
setParam(entry.getKey(), entry.getValue().toString());
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.util;
|
package org.dmlc.xgboost4j.util;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import org.dmlc.xgboost4j.IEvaluation;
|
import org.dmlc.xgboost4j.IEvaluation;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
@ -37,7 +38,7 @@ public class CVPack {
|
|||||||
* @param dtest test data
|
* @param dtest test data
|
||||||
* @param params parameters
|
* @param params parameters
|
||||||
*/
|
*/
|
||||||
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
|
||||||
dmats = new DMatrix[] {dtrain, dtest};
|
dmats = new DMatrix[] {dtrain, dtest};
|
||||||
booster = new Booster(params, dmats);
|
booster = new Booster(params, dmats);
|
||||||
names = new String[] {"train", "test"};
|
names = new String[] {"train", "test"};
|
||||||
|
|||||||
@ -46,21 +46,23 @@ public class Trainer {
|
|||||||
* @param eval customized evaluation (set to null if not used)
|
* @param eval customized evaluation (set to null if not used)
|
||||||
* @return trained booster
|
* @return trained booster
|
||||||
*/
|
*/
|
||||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
|
||||||
WatchList watchs, IObjective obj, IEvaluation eval) {
|
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
int len = watchs.size();
|
String[] evalNames;
|
||||||
int i = 0;
|
DMatrix[] evalMats;
|
||||||
String[] evalNames = new String[len];
|
List<String> names = new ArrayList<>();
|
||||||
DMatrix[] evalMats = new DMatrix[len];
|
List<DMatrix> mats = new ArrayList<>();
|
||||||
|
|
||||||
for(Entry<String, DMatrix> evalEntry : watchs) {
|
for(Entry<String, DMatrix> evalEntry : watchs) {
|
||||||
evalNames[i] = evalEntry.getKey();
|
names.add(evalEntry.getKey());
|
||||||
evalMats[i] = evalEntry.getValue();
|
mats.add(evalEntry.getValue());
|
||||||
i++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
evalNames = names.toArray(new String[names.size()]);
|
||||||
|
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||||
|
|
||||||
//collect all data matrixs
|
//collect all data matrixs
|
||||||
DMatrix[] allMats;
|
DMatrix[] allMats;
|
||||||
if(evalMats!=null && evalMats.length>0) {
|
if(evalMats!=null && evalMats.length>0) {
|
||||||
@ -110,7 +112,7 @@ public class Trainer {
|
|||||||
* @param eval customized evaluation (set to null if not used)
|
* @param eval customized evaluation (set to null if not used)
|
||||||
* @return evaluation history
|
* @return evaluation history
|
||||||
*/
|
*/
|
||||||
public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
||||||
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
||||||
String[] evalHist = new String[round];
|
String[] evalHist = new String[round];
|
||||||
String[] results = new String[cvPacks.length];
|
String[] results = new String[cvPacks.length];
|
||||||
@ -147,7 +149,7 @@ public class Trainer {
|
|||||||
* @param evalMetrics Evaluation metrics
|
* @param evalMetrics Evaluation metrics
|
||||||
* @return CV package array
|
* @return CV package array
|
||||||
*/
|
*/
|
||||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) {
|
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
|
||||||
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
||||||
int step = samples.size()/nfold;
|
int step = samples.size()/nfold;
|
||||||
int[] testSlice = new int[step];
|
int[] testSlice = new int[step];
|
||||||
|
|||||||
@ -1,49 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.dmlc.xgboost4j.util;
|
|
||||||
|
|
||||||
import java.util.AbstractMap;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map.Entry;
|
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* class to handle evaluation dmatrix
|
|
||||||
* @author hzx
|
|
||||||
*/
|
|
||||||
public class WatchList implements Iterable<Entry<String, DMatrix> >{
|
|
||||||
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* put eval dmatrix and it's name
|
|
||||||
* @param name
|
|
||||||
* @param dmat
|
|
||||||
*/
|
|
||||||
public void put(String name, DMatrix dmat) {
|
|
||||||
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
|
|
||||||
}
|
|
||||||
|
|
||||||
public int size() {
|
|
||||||
return watchList.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Iterator<Entry<String, DMatrix>> iterator() {
|
|
||||||
return watchList.iterator();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Loading…
x
Reference in New Issue
Block a user