Merge pull request #904 from tqchen/master

[JVM-PKG] Update JNI to include Rabit interface
This commit is contained in:
Tianqi Chen 2016-03-02 22:44:46 -08:00
commit 0f367a6ade
15 changed files with 688 additions and 287 deletions

5
.gitignore vendored
View File

@ -20,8 +20,9 @@
*buffer *buffer
*model *model
*pyc *pyc
*train *.train
*test *.test
*.tar
*group *group
*rar *rar
*vali *vali

View File

@ -73,7 +73,7 @@ endif
# specify tensor path # specify tensor path
.PHONY: clean all lint clean_all doxygen rcpplint Rpack Rbuild Rcheck java .PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
@ -143,6 +143,11 @@ clean_all: clean
doxygen: doxygen:
doxygen doc/Doxyfile doxygen doc/Doxyfile
# create standalone python tar file.
pypack: ${XGBOOST_DYLIB}
cp ${XGBOOST_DYLIB} python-package/xgboost
cd python-package; tar cf xgboost.tar xgboost; cd ..
# Script to make a clean installable R package. # Script to make a clean installable R package.
Rpack: Rpack:
make clean_all make clean_all

View File

@ -0,0 +1,5 @@
#!/bin/bash
# Simple script to test distributed version, to be deleted later.
cd xgboost4j-demo
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
cd ..

View File

@ -0,0 +1,49 @@
package ml.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Rabit;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* Distributed training example, used to quick test distributed training.
*
* @author tqchen
*/
public class DistTrain {
public static void main(String[] args) throws IOException, XGBoostError {
// always initialize rabit module before training.
Rabit.init(new HashMap<String, String>());
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
// always shutdown rabit module after training.
Rabit.shutdown();
}
}

View File

@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import java.io.IOException; import java.io.IOException;
/** /**
* DMatrix for xgboost, similar to the python wrapper xgboost.py * DMatrix for xgboost.
* *
* @author hzx * @author hzx
*/ */

View File

@ -48,4 +48,5 @@ class JNIErrorHandle {
throw new XGBoostError(XgboostJNI.XGBGetLastError()); throw new XGBoostError(XgboostJNI.XGBGetLastError());
} }
} }
} }

View File

@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
return featureScore; return featureScore;
} }
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/** /**
* transfer DMatrix array to handle array (used for native functions) * transfer DMatrix array to handle array (used for native functions)
* *

View File

@ -0,0 +1,93 @@
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.Map;
import java.io.IOException;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
String[] args = new String[envs.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
checkCall(XgboostJNI.RabitInit(args));
}
/**
* Shutdown the rabit engine in current working thread, equals to finalize.
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XgboostJNI.RabitFinalize());
}
/**
* Print the message on rabit tracker.
* @param msg
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XgboostJNI.RabitTrackerPrint(msg));
}
/**
* Get version number of current stored model in the thread.
* which means how many calls to CheckPoint we made so far.
* @return version Number.
* @throws XGBoostError
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitVersionNumber(out));
return out[0];
}
/**
* get rank of current thread.
* @return the rank.
* @throws XGBoostError
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetRank(out));
return out[0];
}
/**
* get world size of current job.
* @return the worldsize
* @throws XGBoostError
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetWorldSize(out));
return out[0];
}
}

View File

@ -72,15 +72,21 @@ public class XGBoost {
} }
//initialize booster //initialize booster
Booster booster = new JavaBoosterImpl(params, allMats); JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
int version = booster.loadRabitCheckpoint();
//begin to train //begin to train
for (int iter = 0; iter < round; iter++) { for (int iter = version / 2; iter < round; iter++) {
if (version % 2 == 0) {
if (obj != null) { if (obj != null) {
booster.update(dtrain, obj); booster.update(dtrain, obj);
} else { } else {
booster.update(dtrain, iter); booster.update(dtrain, iter);
} }
booster.saveRabitCheckpoint();
version += 1;
}
//evaluation //evaluation
if (evalMats != null && evalMats.length > 0) { if (evalMats != null && evalMats.length > 0) {
@ -90,9 +96,13 @@ public class XGBoost {
} else { } else {
evalInfo = booster.evalSet(evalMats, evalNames, iter); evalInfo = booster.evalSet(evalMats, evalNames, iter);
} }
logger.info(evalInfo); if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n');
} }
} }
booster.saveRabitCheckpoint();
version += 1;
}
return booster; return booster;
} }

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
/** /**
* xgboost jni wrapper functions for xgboost_wrapper.h * xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
* *
* @author hzx * @author hzx
@ -80,4 +80,17 @@ class XgboostJNI {
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings); String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
// rabit functions
public final static native int RabitInit(String[] args);
public final static native int RabitFinalize();
public final static native int RabitTrackerPrint(String msg);
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
} }

View File

@ -10,25 +10,28 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "xgboost/c_api.h" #include <xgboost/c_api.h>
#include "xgboost4j.h" #include "./xgboost4j.h"
#include <cstring> #include <cstring>
#include <vector>
#include <string>
//helper functions //helper functions
//set handle //set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
long out[1]; long out = (long) handle;
out[0] = (long) handle; jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
} }
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) { (JNIEnv *jenv, jclass jcls) {
jstring jresult = 0 ; jstring jresult = 0;
const char* result = XGBGetLastError(); const char* result = XGBGetLastError();
if (result) jresult = jenv->NewStringUTF(result); if (result != NULL) {
jresult = jenv->NewStringUTF(result);
}
return jresult; return jresult;
} }
@ -37,7 +40,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
DMatrixHandle result; DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname); if (fname) {
jenv->ReleaseStringUTFChars(jfname, fname);
}
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -271,12 +276,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles; DMatrixHandle* handles = NULL;
bst_ulong len = 0; bst_ulong len = 0;
jlong* cjhandles = 0; jlong* cjhandles = 0;
BoosterHandle result; BoosterHandle result;
if(jhandles) { if (jhandles) {
len = (bst_ulong)jenv->GetArrayLength(jhandles); len = (bst_ulong)jenv->GetArrayLength(jhandles);
handles = new DMatrixHandle[len]; handles = new DMatrixHandle[len];
//put handle from jhandles to chandles //put handle from jhandles to chandles
@ -288,12 +293,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
int ret = XGBoosterCreate(handles, len, &result); int ret = XGBoosterCreate(handles, len, &result);
//release //release
if(jhandles) { if (jhandles) {
delete[] handles; delete[] handles;
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
} }
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -480,7 +484,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
char *result; char *result;
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
if (result){ if (result) {
jstring jinfo = jenv->NewStringUTF((const char *) result); jstring jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo); jenv->SetObjectArrayElement(jout, 0, jinfo);
} }
@ -511,3 +515,112 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return ret; return ret;
} }
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
jenv->SetIntArrayRegion(jout, 0, 1, &version);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
std::vector<std::string> args;
std::vector<char*> argv;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
std::string s(jenv->GetStringUTFChars(arg, 0),
jenv->GetStringLength(arg));
if (s.length() != 0) args.push_back(s);
}
for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]);
}
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
RabitFinalize();
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
RabitTrackerPrint(str.c_str());
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}

View File

@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetAttr
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetAttr
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -8,11 +8,18 @@ import numpy as np
from .core import _LIB, c_str, STRING_TYPES from .core import _LIB, c_str, STRING_TYPES
def _init_rabit(): def _init_rabit():
"""Initialize the rabit library.""" """internal libary initializer."""
_LIB.RabitGetRank.restype = ctypes.c_int _LIB.RabitGetRank.restype = ctypes.c_int
_LIB.RabitGetWorldSize.restype = ctypes.c_int _LIB.RabitGetWorldSize.restype = ctypes.c_int
_LIB.RabitVersionNumber.restype = ctypes.c_int _LIB.RabitVersionNumber.restype = ctypes.c_int
_LIB.RabitInit(0, None)
def init(args=None):
"""Initialize the rabit libary with arguments"""
if args is None:
args = []
arr = (ctypes.c_char_p * len(args))()
arr[:] = args
_LIB.RabitInit(len(arr), arr)
def finalize(): def finalize():

2
rabit

@ -1 +1 @@
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0 Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043

View File

@ -4,6 +4,9 @@ import scipy.sparse
import pickle import pickle
import xgboost as xgb import xgboost as xgb
# always call this before using distributed module
xgb.rabit.init()
# Load file, file will be automatically sharded in distributed mode. # Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')