Merge pull request #904 from tqchen/master
[JVM-PKG] Update JNI to include Rabit interface
This commit is contained in:
commit
0f367a6ade
5
.gitignore
vendored
5
.gitignore
vendored
@ -20,8 +20,9 @@
|
||||
*buffer
|
||||
*model
|
||||
*pyc
|
||||
*train
|
||||
*test
|
||||
*.train
|
||||
*.test
|
||||
*.tar
|
||||
*group
|
||||
*rar
|
||||
*vali
|
||||
|
||||
7
Makefile
7
Makefile
@ -73,7 +73,7 @@ endif
|
||||
|
||||
|
||||
# specify tensor path
|
||||
.PHONY: clean all lint clean_all doxygen rcpplint Rpack Rbuild Rcheck java
|
||||
.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
|
||||
|
||||
|
||||
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
|
||||
@ -143,6 +143,11 @@ clean_all: clean
|
||||
doxygen:
|
||||
doxygen doc/Doxyfile
|
||||
|
||||
# create standalone python tar file.
|
||||
pypack: ${XGBOOST_DYLIB}
|
||||
cp ${XGBOOST_DYLIB} python-package/xgboost
|
||||
cd python-package; tar cf xgboost.tar xgboost; cd ..
|
||||
|
||||
# Script to make a clean installable R package.
|
||||
Rpack:
|
||||
make clean_all
|
||||
|
||||
5
jvm-packages/test_distributed.sh
Normal file
5
jvm-packages/test_distributed.sh
Normal file
@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
# Simple script to test distributed version, to be deleted later.
|
||||
cd xgboost4j-demo
|
||||
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
|
||||
cd ..
|
||||
@ -0,0 +1,49 @@
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Rabit;
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
|
||||
/**
|
||||
* Distributed training example, used to quick test distributed training.
|
||||
*
|
||||
* @author tqchen
|
||||
*/
|
||||
public class DistTrain {
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
// always initialize rabit module before training.
|
||||
Rabit.init(new HashMap<String, String>());
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
// always shutdown rabit module after training.
|
||||
Rabit.shutdown();
|
||||
}
|
||||
}
|
||||
@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* DMatrix for xgboost, similar to the python wrapper xgboost.py
|
||||
* DMatrix for xgboost.
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
|
||||
@ -48,4 +48,5 @@ class JNIErrorHandle {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the booster model from thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @return the stored version number of the checkpoint.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* transfer DMatrix array to handle array (used for native functions)
|
||||
*
|
||||
|
||||
@ -0,0 +1,93 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import java.util.Map;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Rabit global class for synchronization.
|
||||
*/
|
||||
public class Rabit {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the rabit library on current working thread.
|
||||
* @param envs The additional environment variables to pass to rabit.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||
String[] args = new String[envs.size()];
|
||||
int idx = 0;
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||
}
|
||||
checkCall(XgboostJNI.RabitInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown the rabit engine in current working thread, equals to finalize.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void shutdown() throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitFinalize());
|
||||
}
|
||||
|
||||
/**
|
||||
* Print the message on rabit tracker.
|
||||
* @param msg
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void trackerPrint(String msg) throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitTrackerPrint(msg));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get version number of current stored model in the thread.
|
||||
* which means how many calls to CheckPoint we made so far.
|
||||
* @return version Number.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int versionNumber() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitVersionNumber(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get rank of current thread.
|
||||
* @return the rank.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getRank() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetRank(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get world size of current job.
|
||||
* @return the worldsize
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static int getWorldSize() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
}
|
||||
@ -72,15 +72,21 @@ public class XGBoost {
|
||||
}
|
||||
|
||||
//initialize booster
|
||||
Booster booster = new JavaBoosterImpl(params, allMats);
|
||||
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
|
||||
|
||||
int version = booster.loadRabitCheckpoint();
|
||||
|
||||
//begin to train
|
||||
for (int iter = 0; iter < round; iter++) {
|
||||
for (int iter = version / 2; iter < round; iter++) {
|
||||
if (version % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
version += 1;
|
||||
}
|
||||
|
||||
//evaluation
|
||||
if (evalMats != null && evalMats.length > 0) {
|
||||
@ -90,9 +96,13 @@ public class XGBoost {
|
||||
} else {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||
}
|
||||
logger.info(evalInfo);
|
||||
if (Rabit.getRank() == 0) {
|
||||
Rabit.trackerPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
version += 1;
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
/**
|
||||
* xgboost jni wrapper functions for xgboost_wrapper.h
|
||||
* xgboost JNI functions
|
||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||
*
|
||||
* @author hzx
|
||||
@ -80,4 +80,17 @@ class XgboostJNI {
|
||||
|
||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
||||
String[][] out_strings);
|
||||
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
|
||||
// rabit functions
|
||||
public final static native int RabitInit(String[] args);
|
||||
public final static native int RabitFinalize();
|
||||
public final static native int RabitTrackerPrint(String msg);
|
||||
public final static native int RabitGetRank(int[] out);
|
||||
public final static native int RabitGetWorldSize(int[] out);
|
||||
public final static native int RabitVersionNumber(int[] out);
|
||||
}
|
||||
|
||||
@ -12,23 +12,26 @@
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost4j.h"
|
||||
#include <xgboost/c_api.h>
|
||||
#include "./xgboost4j.h"
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
//helper functions
|
||||
//set handle
|
||||
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
||||
long out[1];
|
||||
out[0] = (long) handle;
|
||||
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
|
||||
long out = (long) handle;
|
||||
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
jstring jresult = 0;
|
||||
const char* result = XGBGetLastError();
|
||||
if (result) jresult = jenv->NewStringUTF(result);
|
||||
if (result != NULL) {
|
||||
jresult = jenv->NewStringUTF(result);
|
||||
}
|
||||
return jresult;
|
||||
}
|
||||
|
||||
@ -37,7 +40,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||
DMatrixHandle result;
|
||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
|
||||
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
||||
if (fname) {
|
||||
jenv->ReleaseStringUTFChars(jfname, fname);
|
||||
}
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
@ -271,7 +276,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
||||
DMatrixHandle* handles;
|
||||
DMatrixHandle* handles = NULL;
|
||||
bst_ulong len = 0;
|
||||
jlong* cjhandles = 0;
|
||||
BoosterHandle result;
|
||||
@ -293,7 +298,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
||||
}
|
||||
setHandle(jenv, jout, result);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -511,3 +515,112 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
int version;
|
||||
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &version);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
return XGBoosterSaveRabitCheckpoint(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
std::vector<std::string> args;
|
||||
std::vector<char*> argv;
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
||||
std::string s(jenv->GetStringUTFChars(arg, 0),
|
||||
jenv->GetStringLength(arg));
|
||||
if (s.length() != 0) args.push_back(s);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
argv.push_back(&args[i][0]);
|
||||
}
|
||||
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
RabitFinalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
|
||||
std::string str(jenv->GetStringUTFChars(jmsg, 0),
|
||||
jenv->GetStringLength(jmsg));
|
||||
RabitTrackerPrint(str.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int rank = RabitGetRank();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int out = RabitGetWorldSize();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int out = RabitVersionNumber();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterGetAttr
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterSetAttr
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -8,11 +8,18 @@ import numpy as np
|
||||
from .core import _LIB, c_str, STRING_TYPES
|
||||
|
||||
def _init_rabit():
|
||||
"""Initialize the rabit library."""
|
||||
"""internal libary initializer."""
|
||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||
_LIB.RabitInit(0, None)
|
||||
|
||||
def init(args=None):
|
||||
"""Initialize the rabit libary with arguments"""
|
||||
if args is None:
|
||||
args = []
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = args
|
||||
_LIB.RabitInit(len(arr), arr)
|
||||
|
||||
|
||||
def finalize():
|
||||
|
||||
2
rabit
2
rabit
@ -1 +1 @@
|
||||
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
||||
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
|
||||
@ -4,6 +4,9 @@ import scipy.sparse
|
||||
import pickle
|
||||
import xgboost as xgb
|
||||
|
||||
# always call this before using distributed module
|
||||
xgb.rabit.init()
|
||||
|
||||
# Load file, file will be automatically sharded in distributed mode.
|
||||
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user