Merge pull request #907 from tqchen/master

[DIST] Enable multiple thread  make rabit and xgboost threadsafe
This commit is contained in:
Tianqi Chen 2016-03-04 08:24:00 -08:00
commit 770b3451ca
18 changed files with 368 additions and 112 deletions

1
.gitignore vendored
View File

@ -78,5 +78,4 @@ tags
*.iml *.iml
*.class *.class
target target
*.swp *.swp

View File

@ -22,6 +22,11 @@ This file records the changes in xgboost library in reverse chronological order.
- The windows version is still blocked due to Rtools do not support ```std::thread```. - The windows version is still blocked due to Rtools do not support ```std::thread```.
* rabit and dmlc-core are maintained through git submodule * rabit and dmlc-core are maintained through git submodule
- Anyone can open PR to update these dependencies now. - Anyone can open PR to update these dependencies now.
* Improvements
- Rabit and xgboost libs are now thread-safe and use thread-local PRNGs
- This could fix some of the previous problems that occurred when running xgboost on multiple threads.
* JVM Package
- Enable xgboost4j for java and scala
## v0.47 (2016.01.14) ## v0.47 (2016.01.14)

2
jvm-packages/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
tracker.py
build.sh

View File

@ -27,6 +27,8 @@ fi
rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl} rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl} mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl}
# copy python to native resources
cp ../dmlc-core/tracker/dmlc_tracker/tracker.py xgboost4j/src/main/resources/tracker.py
popd > /dev/null popd > /dev/null
echo "complete" echo "complete"

2
jvm-packages/test_distributed.sh Normal file → Executable file
View File

@ -1,5 +1,5 @@
#!/bin/bash #!/bin/bash
# Simple script to test distributed version, to be deleted later. # Simple script to test distributed version, to be deleted later.
cd xgboost4j-demo cd xgboost4j-demo
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
cd .. cd ..

View File

@ -2,19 +2,37 @@ package ml.dmlc.xgboost4j.demo;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.*; import ml.dmlc.xgboost4j.*;
/** /**
* Distributed training example, used to quick test distributed training. * Distributed training example, used to quick test distributed training.
* *
* @author tqchen * @author tqchen
*/ */
public class DistTrain { public class DistTrain {
private static final Log logger = LogFactory.getLog(DistTrain.class);
private Map<String, String> envs = null;
public static void main(String[] args) throws IOException, XGBoostError { private class Worker implements Runnable {
private final int workerId;
Worker(int workerId) {
this.workerId = workerId;
}
public void run() {
try {
Map<String, String> worker_env = new HashMap<String, String>(envs);
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
// always initialize rabit module before training. // always initialize rabit module before training.
Rabit.init(new HashMap<String, String>()); Rabit.init(worker_env);
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
@ -24,9 +42,9 @@ public class DistTrain {
params.put("eta", 1.0); params.put("eta", 1.0);
params.put("max_depth", 2); params.put("max_depth", 2);
params.put("silent", 1); params.put("silent", 1);
params.put("nthread", 2);
params.put("objective", "binary:logistic"); params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>(); HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat); watches.put("train", trainMat);
watches.put("test", testMat); watches.put("test", testMat);
@ -39,5 +57,24 @@ public class DistTrain {
// always shutdown rabit module after training. // always shutdown rabit module after training.
Rabit.shutdown(); Rabit.shutdown();
} catch (Exception ex){
logger.error(ex);
}
}
}
/**
 * Launch a tracker and, if it comes up, run {@code nWorkers} worker threads against it.
 * Returns without doing anything further when the tracker fails to start; otherwise
 * blocks in {@code tracker.waitFor()} after all worker threads have been started.
 *
 * @param nWorkers number of worker threads to spawn
 */
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
  RabitTracker rabitTracker = new RabitTracker(nWorkers);
  if (!rabitTracker.start()) {
    // tracker did not come up: nothing to run
    return;
  }
  // workers copy this map and add their own DMLC_TASK_ID
  envs = rabitTracker.getWorkerEnvs();
  int workerId = 0;
  while (workerId < nWorkers) {
    Thread workerThread = new Thread(new Worker(workerId));
    workerThread.start();
    workerId++;
  }
  rabitTracker.waitFor();
}
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
new DistTrain().start(Integer.parseInt(args[0]));
} }
} }

View File

@ -1,9 +1,10 @@
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable;
import java.util.Map; import java.util.Map;
public interface Booster { public interface Booster extends Serializable {
/** /**
* set parameter * set parameter
@ -109,12 +110,25 @@ public interface Booster {
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError; float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
/** /**
* save model to modelPath * save model to modelPath, the model path support depends on the path support
* * in libxgboost. For example, if we want to save to hdfs, libxgboost need to be
* compiled with HDFS support.
* See also toByteArray
* @param modelPath model path * @param modelPath model path
*/ */
void saveModel(String modelPath) throws XGBoostError; void saveModel(String modelPath) throws XGBoostError;
/**
 * Save the model as a byte array.
 * Writing these bytes to a file gives a format compatible with other xgboost bindings.
 *
 * If the Java side has native access to a file system API (for example HDFS),
 * call toByteArray and write the returned bytes through that API.
 *
 * @return the saved byte array.
 * @throws XGBoostError if the model cannot be serialized
 */
byte[] toByteArray() throws XGBoostError;
/** /**
* Dump model into a text file. * Dump model into a text file.
* *

View File

@ -57,7 +57,6 @@ class JavaBoosterImpl implements Booster {
setParams(params); setParams(params);
} }
/** /**
* load model from modelPath * load model from modelPath
* *
@ -440,6 +439,22 @@ class JavaBoosterImpl implements Booster {
return featureScore; return featureScore;
} }
/**
 * Serialize the model into a byte array.
 * Writing these bytes to a file gives a format compatible with other xgboost bindings.
 *
 * If the Java side has native access to a file system API (for example HDFS),
 * call toByteArray and write the returned bytes through that API.
 *
 * @return the saved byte array.
 * @throws XGBoostError if the native call reports a failure
 */
public byte[] toByteArray() throws XGBoostError {
  // single-slot holder: the native call stores the model bytes in slot 0
  final byte[][] rawModelHolder = new byte[1][];
  final int status = XgboostJNI.XGBoosterGetModelRaw(handle, rawModelHolder);
  JNIErrorHandle.checkCall(status);
  return rawModelHolder[0];
}
/** /**
* Load the booster model from thread-local rabit checkpoint. * Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training. * This is only used in distributed training.
@ -475,6 +490,27 @@ class JavaBoosterImpl implements Booster {
return handles; return handles;
} }
// making Booster serializable
/**
 * Java serialization hook: writes the model as the byte array produced by
 * {@link #toByteArray()}.
 *
 * @param out stream the serialized model is written to
 * @throws IOException if the stream write fails, or wrapping the XGBoostError
 *                     (kept as the cause) when the model cannot be exported
 */
private void writeObject(java.io.ObjectOutputStream out)
    throws IOException {
  try {
    out.writeObject(this.toByteArray());
  } catch (XGBoostError ex) {
    // keep the original error as the cause so its stack trace is not lost
    throw new IOException(ex.toString(), ex);
  }
}
/**
 * Java deserialization hook: re-initializes the booster via {@code init(null)}
 * and then loads the model from the byte array read from the stream.
 *
 * @param in stream the serialized model is read from
 * @throws IOException if the stream read fails, or wrapping the XGBoostError
 *                     (kept as the cause) when initialization or model loading fails
 * @throws ClassNotFoundException propagated from {@code in.readObject()}
 */
private void readObject(java.io.ObjectInputStream in)
    throws IOException, ClassNotFoundException {
  try {
    // the booster must be initialized before the model buffer is loaded into it
    this.init(null);
    byte[] bytes = (byte[]) in.readObject();
    JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
  } catch (XGBoostError ex) {
    // keep the original error as the cause so its stack trace is not lost
    throw new IOException(ex.toString(), ex);
  }
}
@Override @Override
protected void finalize() throws Throwable { protected void finalize() throws Throwable {
super.finalize(); super.finalize();

View File

@ -21,6 +21,7 @@ import java.lang.reflect.Field;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
/** /**
* class to load native library * class to load native library
* *
@ -61,12 +62,32 @@ class NativeLibLoader {
* three characters * three characters
*/ */
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
String temp = createTempFileFromResource(path);
// Finally, load the library
System.load(temp);
}
/**
* Create a temp file that copies the resource from current JAR archive
* <p/>
* The file from JAR is copied into system temp file.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
*/
static String createTempFileFromResource(String path) throws
IOException, IllegalArgumentException {
// Obtain filename from path
if (!path.startsWith("/")) { if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/')."); throw new IllegalArgumentException("The path has to be absolute (start with '/').");
} }
// Obtain filename from path
String[] parts = path.split("/"); String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null; String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
@ -83,7 +104,6 @@ class NativeLibLoader {
if (filename == null || prefix.length() < 3) { if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long."); throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
} }
// Prepare temporary file // Prepare temporary file
File temp = File.createTempFile(prefix, suffix); File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit(); temp.deleteOnExit();
@ -113,9 +133,7 @@ class NativeLibLoader {
os.close(); os.close();
is.close(); is.close();
} }
return temp.getAbsolutePath();
// Finally, load the library
System.load(temp.getAbsolutePath());
} }
/** /**
@ -133,8 +151,9 @@ class NativeLibLoader {
try { try {
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName); String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
loadLibraryFromJar(libraryFromJar); loadLibraryFromJar(libraryFromJar);
} catch (IOException e1) { } catch (IOException ioe) {
throw e1; logger.error("failed to load library from both native path and jar");
throw ioe;
} }
} }
} }

View File

@ -0,0 +1,144 @@
package ml.dmlc.xgboost4j;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
 * Distributed RabitTracker, needs to be started in driver code before running distributed jobs.
 *
 * The tracker is an external python process running the bundled tracker script. After
 * {@link #start()} succeeds, {@link #getWorkerEnvs()} holds the key/value settings the
 * tracker printed on its standard output between the DMLC_TRACKER_ENV_START and
 * DMLC_TRACKER_ENV_END marker lines.
 */
public class RabitTracker {
  // Maybe per tracker logger?
  private static final Log logger = LogFactory.getLog(RabitTracker.class);
  // path of the tracker python file copied out of the jar; stays null if the copy failed.
  private static String tracker_py = null;
  // environment variables to be passed to the workers.
  private Map<String, String> envs = new HashMap<String, String>();
  // number of workers to be submitted.
  private int num_workers;
  private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();

  static {
    try {
      initTrackerPy();
    } catch (IOException ex) {
      logger.error("load tracker library failed.");
      logger.error(ex);
    }
  }

  /**
   * Tracker logger that forwards the tracker's error stream to the log, line by line.
   */
  private class TrackerProcessLogger implements Runnable {
    public void run() {
      Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
      BufferedReader reader = new BufferedReader(new InputStreamReader(
          trackerProcess.get().getErrorStream()));
      String line;
      try {
        while ((line = reader.readLine()) != null) {
          trackerProcessLogger.info(line);
        }
      } catch (IOException ex) {
        trackerProcessLogger.error(ex.toString());
      }
    }
  }

  /**
   * Copy the tracker script from the jar resources into a temp file and remember its path.
   *
   * @throws IOException if the resource cannot be copied
   */
  private static void initTrackerPy() throws IOException {
    try {
      tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
    } catch (IOException ioe) {
      logger.trace("cannot access tracker python script");
      throw ioe;
    }
  }

  public RabitTracker(int num_workers) {
    this.num_workers = num_workers;
  }

  /**
   * Get environments that can be used to pass to worker.
   * @return The environment settings.
   */
  public Map<String, String> getWorkerEnvs() {
    return envs;
  }

  /**
   * Parse the environment block printed by the tracker into {@code envs}.
   *
   * Lines before the DMLC_TRACKER_ENV_START marker are ignored. Reading stops at the
   * DMLC_TRACKER_ENV_END marker or at end of stream. Each line is split at its first
   * '=' so that values which themselves contain '=' are kept intact; lines with an
   * empty key or an empty value are skipped.
   *
   * @param ins standard output of the tracker process
   * @throws IOException if reading fails or the stream ends before the start marker
   */
  private void loadEnvs(InputStream ins) throws IOException {
    try {
      BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
      String line;
      // Look for the start marker explicitly: doing the read inside an assert would
      // only consume and check the line when assertions are enabled.
      boolean started = false;
      while ((line = reader.readLine()) != null) {
        if (line.trim().equals("DMLC_TRACKER_ENV_START")) {
          started = true;
          break;
        }
      }
      if (!started) {
        throw new IOException("tracker output ended before DMLC_TRACKER_ENV_START");
      }
      while ((line = reader.readLine()) != null) {
        if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
          break;
        }
        int eq = line.indexOf('=');
        if (eq > 0 && eq < line.length() - 1) {
          envs.put(line.substring(0, eq), line.substring(eq + 1));
        }
      }
      reader.close();
    } catch (IOException ioe) {
      logger.error("cannot get runtime configuration from tracker process", ioe);
      throw ioe;
    }
  }

  /**
   * Spawn the python tracker and read its environment block.
   *
   * @return true if the process was launched and its environment block was read
   */
  private boolean startTrackerProcess() {
    if (tracker_py == null) {
      // static initialization failed, there is no script to run
      logger.error("tracker python script is not available");
      return false;
    }
    try {
      // argument-array form: a temp path containing spaces must stay one argument
      String[] command = new String[] {
          "python", tracker_py, "--num-workers=" + String.valueOf(num_workers)};
      trackerProcess.set(Runtime.getRuntime().exec(command));
      loadEnvs(trackerProcess.get().getInputStream());
      return true;
    } catch (IOException ioe) {
      logger.error("cannot start tracker process", ioe);
      return false;
    }
  }

  /** Destroy the tracker process if one was launched. */
  private void stop() {
    Process process = trackerProcess.get();
    if (process != null) {
      process.destroy();
    }
  }

  /**
   * Start the tracker process and a daemon thread that logs its error stream.
   *
   * @return true on success; on failure any launched process is destroyed and false is returned
   */
  public boolean start() {
    if (startTrackerProcess()) {
      logger.debug("Tracker started, with env=" + envs.toString());
      // also start a tracker logger
      Thread logger_thread = new Thread(new TrackerProcessLogger());
      logger_thread.setDaemon(true);
      logger_thread.start();
      return true;
    } else {
      logger.error("FAULT: failed to start tracker process");
      stop();
      return false;
    }
  }

  /**
   * Block until the tracker process exits, then log its exit code.
   * Logs an error and returns immediately if no tracker process was ever launched.
   */
  public void waitFor() {
    Process process = trackerProcess.get();
    if (process == null) {
      logger.error("waitFor called but the tracker process was never started");
      return;
    }
    try {
      process.waitFor();
      logger.info("Tracker Process ends with exit code " + process.exitValue());
      stop();
    } catch (InterruptedException e) {
      // we should not get here as RabitTracker is accessed in the main thread
      logger.error("the RabitTracker thread is terminated unexpectedly", e);
      // restore the interrupt flag so the caller can still observe the interruption
      Thread.currentThread().interrupt();
    }
  }
}

View File

@ -74,9 +74,9 @@ class XgboostJNI {
public final static native int XGBoosterSaveModel(long handle, String fname); public final static native int XGBoosterSaveModel(long handle, String fname);
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len); public final static native int XGBoosterLoadModelFromBuffer(long handle, byte[] bytes);
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string); public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings); String[][] out_strings);

View File

@ -23,7 +23,7 @@ import scala.collection.mutable
import ml.dmlc.xgboost4j.XGBoostError import ml.dmlc.xgboost4j.XGBoostError
trait Booster { trait Booster extends Serializable {
/** /**

View File

@ -13,6 +13,8 @@
*/ */
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include "./xgboost4j.h" #include "./xgboost4j.h"
#include <cstring> #include <cstring>
#include <vector> #include <vector>
@ -276,27 +278,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles = NULL; std::vector<DMatrixHandle> handles;
bst_ulong len = 0; if (jhandles != nullptr) {
jlong* cjhandles = 0; size_t len = jenv->GetArrayLength(jhandles);
BoosterHandle result; jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0);
for (size_t i = 0; i < len; ++i) {
if (jhandles) { handles.push_back((DMatrixHandle) cjhandles[i]);
len = (bst_ulong)jenv->GetArrayLength(jhandles);
handles = new DMatrixHandle[len];
//put handle from jhandles to chandles
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
for(bst_ulong i=0; i<len; i++) {
handles[i] = (DMatrixHandle) cjhandles[i];
} }
}
int ret = XGBoosterCreate(handles, len, &result);
//release
if (jhandles) {
delete[] handles;
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
} }
BoosterHandle result;
int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;
} }
@ -369,43 +361,34 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle* dmats = 0; std::vector<DMatrixHandle> dmats;
char **evnames = 0; std::vector<std::string> evnames;
char *result = 0; std::vector<const char*> evchars;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
if(len > 0) { size_t len = static_cast<size_t>(jenv->GetArrayLength(jdmats));
dmats = new DMatrixHandle[len];
evnames = new char*[len];
}
// put handle from jhandles to chandles // put handle from jhandles to chandles
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
for(bst_ulong i=0; i<len; i++) { for (size_t i = 0; i < len; ++i) {
dmats[i] = (DMatrixHandle) cjdmats[i]; dmats.push_back((DMatrixHandle) cjdmats[i]);
}
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
for(bst_ulong i=0; i<len; i++) {
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i); jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
const char* cevname = jenv->GetStringUTFChars(jevname, 0); const char *s =jenv->GetStringUTFChars(jevname, 0);
evnames[i] = new char[jenv->GetStringLength(jevname)]; evnames.push_back(std::string(s, jenv->GetStringLength(jevname)));
strcpy(evnames[i], cevname); if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s);
jenv->ReleaseStringUTFChars(jevname, cevname);
} }
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
if(len > 0) {
delete[] dmats;
//release string chars
for(bst_ulong i=0; i<len; i++) {
delete[] evnames[i];
}
delete[] evnames;
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
for (size_t i = 0; i < len; ++i) {
evchars.push_back(evnames[i].c_str());
}
const char* result;
int ret = XGBoosterEvalOneIter(handle, jiter,
dmlc::BeginPtr(dmats),
dmlc::BeginPtr(evchars),
len, &result);
jstring jinfo = nullptr;
if (result != nullptr) {
jinfo = jenv->NewStringUTF(result);
} }
jstring jinfo = 0;
if (result) jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo); jenv->SetObjectArrayElement(jout, 0, jinfo);
return ret; return ret;
} }
@ -456,37 +439,40 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
int ret = XGBoosterSaveModel(handle, fname); int ret = XGBoosterSaveModel(handle, fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname); if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
return ret; return ret;
} }
/* /*
* Class: ml_dmlc_xgboost4j_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer * Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V * Signature: (J[B)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
void *buf = (void*) jbuf; jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen); int ret = XGBoosterLoadModelFromBuffer(
handle, buffer, jenv->GetArrayLength(jbytes));
jenv->ReleaseByteArrayElements(jbytes, buffer, 0);
return ret;
} }
/* /*
* Class: ml_dmlc_xgboost4j_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetModelRaw * Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String; * Signature: (J[[B)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0; bst_ulong len = 0;
char *result; const char* result;
int ret = XGBoosterGetModelRaw(handle, &len, &result);
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
if (result) { if (result) {
jstring jinfo = jenv->NewStringUTF((const char *) result); jbyteArray jarray = jenv->NewByteArray(len);
jenv->SetObjectArrayElement(jout, 0, jinfo); jenv->SetByteArrayRegion(jarray, 0, len, (jbyte*)result);
jenv->SetObjectArrayElement(jout, 0, jarray);
} }
return ret; return ret;
} }
@ -553,15 +539,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) { for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i); jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
std::string s(jenv->GetStringUTFChars(arg, 0), const char *s = jenv->GetStringUTFChars(arg, 0);
jenv->GetStringLength(arg)); args.push_back(std::string(s, jenv->GetStringLength(arg)));
if (s.length() != 0) args.push_back(s); if (s != nullptr) jenv->ReleaseStringUTFChars(arg, s);
if (args.back().length() == 0) args.pop_back();
} }
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]); argv.push_back(&args[i][0]);
} }
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
RabitInit(args.size(), dmlc::BeginPtr(argv));
return 0; return 0;
} }

View File

@ -194,15 +194,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
/* /*
* Class: ml_dmlc_xgboost4j_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer * Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)I * Signature: (J[B)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *, jclass, jlong, jlong, jlong); (JNIEnv *, jclass, jlong, jbyteArray);
/* /*
* Class: ml_dmlc_xgboost4j_XgboostJNI * Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetModelRaw * Method: XGBoosterGetModelRaw
* Signature: (J[Ljava/lang/String;)I * Signature: (J[[B)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv *, jclass, jlong, jobjectArray); (JNIEnv *, jclass, jlong, jobjectArray);

View File

@ -4,12 +4,20 @@
* \brief Enable all kinds of global variables in common. * \brief Enable all kinds of global variables in common.
*/ */
#include "./random.h" #include "./random.h"
#include "./thread_local.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
/*! \brief entry kept in the thread local store, wrapping the random engine. */
struct RandomThreadLocalEntry {
  /*! \brief random engine instance held by this entry. */
  GlobalRandomEngine engine;
};

/*! \brief thread local store specialized for the random entry. */
using RandomThreadLocalStore = ThreadLocalStore<RandomThreadLocalEntry>;
GlobalRandomEngine& GlobalRandom() { GlobalRandomEngine& GlobalRandom() {
static GlobalRandomEngine inst; return RandomThreadLocalStore::Get()->engine;
return inst;
}
} }
} // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -61,7 +61,8 @@ typedef RandomEngine GlobalRandomEngine;
/*! /*!
* \brief global singleton of a random engine. * \brief global singleton of a random engine.
* Only use this engine when necessary, not thread-safe. * This random engine is thread-local and
* only visible to current thread.
*/ */
GlobalRandomEngine& GlobalRandom(); // NOLINT(*) GlobalRandomEngine& GlobalRandom(); // NOLINT(*)

View File

@ -6,6 +6,8 @@
#ifndef XGBOOST_COMMON_THREAD_LOCAL_H_ #ifndef XGBOOST_COMMON_THREAD_LOCAL_H_
#define XGBOOST_COMMON_THREAD_LOCAL_H_ #define XGBOOST_COMMON_THREAD_LOCAL_H_
#include <dmlc/base.h>
#if DMLC_ENABLE_STD_THREAD #if DMLC_ENABLE_STD_THREAD
#include <mutex> #include <mutex>
#endif #endif

View File

@ -15,6 +15,7 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <limits> #include <limits>
#include "../common/common.h"
namespace xgboost { namespace xgboost {
namespace gbm { namespace gbm {
@ -265,13 +266,11 @@ class GBTree : public GradientBooster {
inline void InitUpdater() { inline void InitUpdater() {
if (updaters.size() != 0) return; if (updaters.size() != 0) return;
std::string tval = tparam.updater_seq; std::string tval = tparam.updater_seq;
char *pstr; std::vector<std::string> ups = common::Split(tval, ',');
pstr = std::strtok(&tval[0], ","); for (const std::string& pstr : ups) {
while (pstr != nullptr) { std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str()));
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr));
up->Init(this->cfg); up->Init(this->cfg);
updaters.push_back(std::move(up)); updaters.push_back(std::move(up));
pstr = std::strtok(nullptr, ",");
} }
} }
// do group specific group // do group specific group