make some fix

This commit is contained in:
yanqingmen
2015-06-10 20:09:49 -07:00
parent f91a098770
commit c110111f52
22 changed files with 234 additions and 162 deletions

View File

@@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.TransferUtil;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@@ -85,7 +84,7 @@ public final class Booster {
private void init(DMatrix[] dMatrixs) {
long[] handles = null;
if(dMatrixs != null) {
handles = TransferUtil.dMatrixs2handles(dMatrixs);
handles = dMatrixs2handles(dMatrixs);
}
handle = XgboostJNI.XGBoosterCreate(handles);
}
@@ -105,8 +104,8 @@ public final class Booster {
*/
public void setParams(Params params) {
if(params!=null) {
for(Map.Entry<String, String> entry : params) {
setParam(entry.getKey(), entry.getValue());
for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
@@ -154,7 +153,7 @@ public final class Booster {
* @return eval information
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs);
long[] handles = dMatrixs2handles(evalMatrixs);
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
return evalInfo;
}
@@ -424,6 +423,19 @@ public final class Booster {
return featureScore;
}
/**
* transfer DMatrix array to handle array (used for native functions)
* @param dmatrixs
* @return handle array for input dmatrixs
*/
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for(int i=0; i<dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
@Override
protected void finalize() {
delete();

View File

@@ -19,7 +19,6 @@ import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.TransferUtil;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
@@ -126,7 +125,7 @@ public class DMatrix {
* @param baseMargin
*/
public void setBaseMargin(float[][] baseMargin) {
float[] flattenMargin = TransferUtil.flatten(baseMargin);
float[] flattenMargin = flatten(baseMargin);
setBaseMargin(flattenMargin);
}
@@ -203,6 +202,24 @@ public class DMatrix {
return handle;
}
/**
* flatten a mat to array
* @param mat
* @return
*/
private static float[] flatten(float[][] mat) {
int size = 0;
for (float[] array : mat) size += array.length;
float[] result = new float[size];
int pos = 0;
for (float[] ar : mat) {
System.arraycopy(ar, 0, result, pos, ar.length);
pos += ar.length;
}
return result;
}
@Override
protected void finalize() {
delete();

View File

@@ -28,7 +28,6 @@ public class CVPack {
DMatrix dtrain;
DMatrix dtest;
DMatrix[] dmats;
long[] dataArray;
String[] names;
Booster booster;
@@ -41,7 +40,6 @@ public class CVPack {
public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats);
dataArray = TransferUtil.dMatrixs2handles(dmats);
names = new String[] {"train", "test"};
this.dtrain = dtrain;
this.dtest = dtest;
@@ -70,7 +68,7 @@ public class CVPack {
* @return
*/
public String eval(int iter) {
return booster.evalSet(dataArray, names, iter);
return booster.evalSet(dmats, names, iter);
}
/**

View File

@@ -43,7 +43,7 @@ public class Initializer {
}
/**
* load native library, this method will first try to load library from java.library.path, then try to load from library in jar package.
* load native library, this method will first try to load library from java.library.path, then try to load library in jar package.
* @param libName
* @throws IOException
*/

View File

@@ -1,7 +1,17 @@
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;

View File

@@ -26,29 +26,29 @@ import java.util.AbstractMap;
* a util class for handle params
* @author hzx
*/
public class Params implements Iterable<Entry<String, String>>{
List<Entry<String, String>> params = new ArrayList<>();
public class Params implements Iterable<Entry<String, Object>>{
List<Entry<String, Object>> params = new ArrayList<>();
/**
* put param key-value pair
* @param key
* @param value
*/
public void put(String key, String value) {
public void put(String key, Object value) {
params.add(new AbstractMap.SimpleEntry<>(key, value));
}
@Override
public String toString(){
String paramsInfo = "";
for(Entry<String, String> param : params) {
for(Entry<String, Object> param : params) {
paramsInfo += param.getKey() + ":" + param.getValue() + "\n";
}
return paramsInfo;
}
@Override
public Iterator<Entry<String, String>> iterator() {
public Iterator<Entry<String, Object>> iterator() {
return params.iterator();
}
}

View File

@@ -20,6 +20,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.IEvaluation;
@@ -40,14 +41,26 @@ public class Trainer {
* @param params Booster params.
* @param dtrain Data to be trained.
* @param round Number of boosting iterations.
* @param evalMats Data to be evaluated (may include dtrain)
* @param evalNames name of data (used for evaluation info)
* @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set.
* @param obj customized objective (set to null if not used)
* @param eval customized evaluation (set to null if not used)
* @return trained booster
*/
public static Booster train(Params params, DMatrix dtrain, int round,
DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) {
public static Booster train(Params params, DMatrix dtrain, int round,
WatchList watchs, IObjective obj, IEvaluation eval) {
//collect eval matrixs
int len = watchs.size();
int i = 0;
String[] evalNames = new String[len];
DMatrix[] evalMats = new DMatrix[len];
for(Entry<String, DMatrix> evalEntry : watchs) {
evalNames[i] = evalEntry.getKey();
evalMats[i] = evalEntry.getValue();
i++;
}
//collect all data matrixs
DMatrix[] allMats;
if(evalMats!=null && evalMats.length>0) {
@@ -63,16 +76,6 @@ public class Trainer {
//initialize booster
Booster booster = new Booster(params, allMats);
//used for evaluation
long[] dataArray = null;
String[] names = null;
if(dataArray==null || names==null) {
//prepare data for evaluation
dataArray = TransferUtil.dMatrixs2handles(evalMats);
names = evalNames;
}
//begin to train
for(int iter=0; iter<round; iter++) {
if(obj != null) {
@@ -88,7 +91,7 @@ public class Trainer {
evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
}
else {
evalInfo = booster.evalSet(dataArray, names, iter);
evalInfo = booster.evalSet(evalMats, evalNames, iter);
}
logger.info(evalInfo);
}

View File

@@ -1,55 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import org.dmlc.xgboost4j.DMatrix;
/**
*
* @author hzx
*/
public class TransferUtil {
/**
* transfer DMatrix array to handle array (used for native functions)
* @param dmatrixs
* @return handle array for input dmatrixs
*/
public static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for(int i=0; i<dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
/**
* flatten a mat to array
* @param mat
* @return
*/
public static float[] flatten(float[][] mat) {
int size = 0;
for (float[] array : mat) size += array.length;
float[] result = new float[size];
int pos = 0;
for (float[] ar : mat) {
System.arraycopy(ar, 0, result, pos, ar.length);
pos += ar.length;
}
return result;
}
}

View File

@@ -0,0 +1,49 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import org.dmlc.xgboost4j.DMatrix;
/**
* class to handle evaluation dmatrix
* @author hzx
*/
public class WatchList implements Iterable<Entry<String, DMatrix> >{
List<Entry<String, DMatrix>> watchList = new ArrayList<>();
/**
* put eval dmatrix and it's name
* @param name
* @param dmat
*/
public void put(String name, DMatrix dmat) {
watchList.add(new AbstractMap.SimpleEntry<>(name, dmat));
}
public int size() {
return watchList.size();
}
@Override
public Iterator<Entry<String, DMatrix>> iterator() {
return watchList.iterator();
}
}

View File

@@ -1 +0,0 @@
please put native library in this package.