make some fix
This commit is contained in:
@@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.Params;
|
||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
|
||||
@@ -85,7 +84,7 @@ public final class Booster {
|
||||
private void init(DMatrix[] dMatrixs) {
|
||||
long[] handles = null;
|
||||
if(dMatrixs != null) {
|
||||
handles = TransferUtil.dMatrixs2handles(dMatrixs);
|
||||
handles = dMatrixs2handles(dMatrixs);
|
||||
}
|
||||
handle = XgboostJNI.XGBoosterCreate(handles);
|
||||
}
|
||||
@@ -105,8 +104,8 @@ public final class Booster {
|
||||
*/
|
||||
public void setParams(Params params) {
|
||||
if(params!=null) {
|
||||
for(Map.Entry<String, String> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue());
|
||||
for(Map.Entry<String, Object> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue().toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -154,7 +153,7 @@ public final class Booster {
|
||||
* @return eval information
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
||||
long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs);
|
||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
||||
return evalInfo;
|
||||
}
|
||||
@@ -424,6 +423,19 @@ public final class Booster {
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for(int i=0; i<dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
delete();
|
||||
|
||||
@@ -19,7 +19,6 @@ import java.io.IOException;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.TransferUtil;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
/**
|
||||
@@ -126,7 +125,7 @@ public class DMatrix {
|
||||
* @param baseMargin
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) {
|
||||
float[] flattenMargin = TransferUtil.flatten(baseMargin);
|
||||
float[] flattenMargin = flatten(baseMargin);
|
||||
setBaseMargin(flattenMargin);
|
||||
}
|
||||
|
||||
@@ -203,6 +202,24 @@ public class DMatrix {
|
||||
return handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
private static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
delete();
|
||||
|
||||
@@ -28,7 +28,6 @@ public class CVPack {
|
||||
DMatrix dtrain;
|
||||
DMatrix dtest;
|
||||
DMatrix[] dmats;
|
||||
long[] dataArray;
|
||||
String[] names;
|
||||
Booster booster;
|
||||
|
||||
@@ -41,7 +40,6 @@ public class CVPack {
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
|
||||
dmats = new DMatrix[] {dtrain, dtest};
|
||||
booster = new Booster(params, dmats);
|
||||
dataArray = TransferUtil.dMatrixs2handles(dmats);
|
||||
names = new String[] {"train", "test"};
|
||||
this.dtrain = dtrain;
|
||||
this.dtest = dtest;
|
||||
@@ -70,7 +68,7 @@ public class CVPack {
|
||||
* @return
|
||||
*/
|
||||
public String eval(int iter) {
|
||||
return booster.evalSet(dataArray, names, iter);
|
||||
return booster.evalSet(dmats, names, iter);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -43,7 +43,7 @@ public class Initializer {
|
||||
}
|
||||
|
||||
/**
|
||||
* load native library, this method will first try to load library from java.library.path, then try to load from library in jar package.
|
||||
* load native library, this method will first try to load library from java.library.path, then try to load library in jar package.
|
||||
* @param libName
|
||||
* @throws IOException
|
||||
*/
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
/*
|
||||
* To change this license header, choose License Headers in Project Properties.
|
||||
* To change this template file, choose Tools | Templates
|
||||
* and open the template in the editor.
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
|
||||
@@ -26,29 +26,29 @@ import java.util.AbstractMap;
|
||||
* a util class for handle params
|
||||
* @author hzx
|
||||
*/
|
||||
public class Params implements Iterable<Entry<String, String>>{
|
||||
List<Entry<String, String>> params = new ArrayList<>();
|
||||
public class Params implements Iterable<Entry<String, Object>>{
|
||||
List<Entry<String, Object>> params = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put param key-value pair
|
||||
* @param key
|
||||
* @param value
|
||||
*/
|
||||
public void put(String key, String value) {
|
||||
public void put(String key, Object value) {
|
||||
params.add(new AbstractMap.SimpleEntry<>(key, value));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
String paramsInfo = "";
|
||||
for(Entry<String, String> param : params) {
|
||||
for(Entry<String, Object> param : params) {
|
||||
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
|
||||
}
|
||||
return paramsInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, String>> iterator() {
|
||||
public Iterator<Entry<String, Object>> iterator() {
|
||||
return params.iterator();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
@@ -40,14 +41,26 @@ public class Trainer {
|
||||
* @param params Booster params.
|
||||
* @param dtrain Data to be trained.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param evalMats Data to be evaluated (may include dtrain)
|
||||
* @param evalNames name of data (used for evaluation info)
|
||||
* @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set.
|
||||
* @param obj customized objective (set to null if not used)
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return trained booster
|
||||
*/
|
||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||
DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) {
|
||||
public static Booster train(Params params, DMatrix dtrain, int round,
|
||||
WatchList watchs, IObjective obj, IEvaluation eval) {
|
||||
|
||||
//collect eval matrixs
|
||||
int len = watchs.size();
|
||||
int i = 0;
|
||||
String[] evalNames = new String[len];
|
||||
DMatrix[] evalMats = new DMatrix[len];
|
||||
|
||||
for(Entry<String, DMatrix> evalEntry : watchs) {
|
||||
evalNames[i] = evalEntry.getKey();
|
||||
evalMats[i] = evalEntry.getValue();
|
||||
i++;
|
||||
}
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
if(evalMats!=null && evalMats.length>0) {
|
||||
@@ -63,16 +76,6 @@ public class Trainer {
|
||||
//initialize booster
|
||||
Booster booster = new Booster(params, allMats);
|
||||
|
||||
//used for evaluation
|
||||
long[] dataArray = null;
|
||||
String[] names = null;
|
||||
|
||||
if(dataArray==null || names==null) {
|
||||
//prepare data for evaluation
|
||||
dataArray = TransferUtil.dMatrixs2handles(evalMats);
|
||||
names = evalNames;
|
||||
}
|
||||
|
||||
//begin to train
|
||||
for(int iter=0; iter<round; iter++) {
|
||||
if(obj != null) {
|
||||
@@ -88,7 +91,7 @@ public class Trainer {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
|
||||
}
|
||||
else {
|
||||
evalInfo = booster.evalSet(dataArray, names, iter);
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||
}
|
||||
logger.info(evalInfo);
|
||||
}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class TransferUtil {
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
public static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for(int i=0; i<dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
}
|
||||
return handles;
|
||||
}
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
public static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
for (float[] array : mat) size += array.length;
|
||||
float[] result = new float[size];
|
||||
int pos = 0;
|
||||
for (float[] ar : mat) {
|
||||
System.arraycopy(ar, 0, result, pos, ar.length);
|
||||
pos += ar.length;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
|
||||
/**
|
||||
* class to handle evaluation dmatrix
|
||||
* @author hzx
|
||||
*/
|
||||
public class WatchList implements Iterable<Entry<String, DMatrix> >{
|
||||
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* put eval dmatrix and it's name
|
||||
* @param name
|
||||
* @param dmat
|
||||
*/
|
||||
public void put(String name, DMatrix dmat) {
|
||||
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return watchList.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Entry<String, DMatrix>> iterator() {
|
||||
return watchList.iterator();
|
||||
}
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
please put native library in this package.
|
||||
Reference in New Issue
Block a user