Merge pull request #904 from tqchen/master
[JVM-PKG] Update JNI to include Rabit interface
This commit is contained in:
commit
0f367a6ade
5
.gitignore
vendored
5
.gitignore
vendored
@ -20,8 +20,9 @@
|
|||||||
*buffer
|
*buffer
|
||||||
*model
|
*model
|
||||||
*pyc
|
*pyc
|
||||||
*train
|
*.train
|
||||||
*test
|
*.test
|
||||||
|
*.tar
|
||||||
*group
|
*group
|
||||||
*rar
|
*rar
|
||||||
*vali
|
*vali
|
||||||
|
|||||||
7
Makefile
7
Makefile
@ -73,7 +73,7 @@ endif
|
|||||||
|
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
.PHONY: clean all lint clean_all doxygen rcpplint Rpack Rbuild Rcheck java
|
.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
|
||||||
|
|
||||||
|
|
||||||
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
|
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
|
||||||
@ -143,6 +143,11 @@ clean_all: clean
|
|||||||
doxygen:
|
doxygen:
|
||||||
doxygen doc/Doxyfile
|
doxygen doc/Doxyfile
|
||||||
|
|
||||||
|
# create standalone python tar file.
|
||||||
|
pypack: ${XGBOOST_DYLIB}
|
||||||
|
cp ${XGBOOST_DYLIB} python-package/xgboost
|
||||||
|
cd python-package; tar cf xgboost.tar xgboost; cd ..
|
||||||
|
|
||||||
# Script to make a clean installable R package.
|
# Script to make a clean installable R package.
|
||||||
Rpack:
|
Rpack:
|
||||||
make clean_all
|
make clean_all
|
||||||
|
|||||||
5
jvm-packages/test_distributed.sh
Normal file
5
jvm-packages/test_distributed.sh
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Simple script to test distributed version, to be deleted later.
|
||||||
|
cd xgboost4j-demo
|
||||||
|
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
|
||||||
|
cd ..
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
package ml.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.Rabit;
|
||||||
|
import ml.dmlc.xgboost4j.Booster;
|
||||||
|
import ml.dmlc.xgboost4j.DMatrix;
|
||||||
|
import ml.dmlc.xgboost4j.XGBoost;
|
||||||
|
import ml.dmlc.xgboost4j.XGBoostError;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Distributed training example, used to quick test distributed training.
|
||||||
|
*
|
||||||
|
* @author tqchen
|
||||||
|
*/
|
||||||
|
public class DistTrain {
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException, XGBoostError {
|
||||||
|
// always initialize rabit module before training.
|
||||||
|
Rabit.init(new HashMap<String, String>());
|
||||||
|
|
||||||
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth", 2);
|
||||||
|
params.put("silent", 1);
|
||||||
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
|
|
||||||
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
//set round
|
||||||
|
int round = 2;
|
||||||
|
|
||||||
|
//train a boost model
|
||||||
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
|
// always shutdown rabit module after training.
|
||||||
|
Rabit.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DMatrix for xgboost, similar to the python wrapper xgboost.py
|
* DMatrix for xgboost.
|
||||||
*
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -48,4 +48,5 @@ class JNIErrorHandle {
|
|||||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
|
|||||||
return featureScore;
|
return featureScore;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load the booster model from thread-local rabit checkpoint.
|
||||||
|
* This is only used in distributed training.
|
||||||
|
* @return the stored version number of the checkpoint.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
int loadRabitCheckpoint() throws XGBoostError {
|
||||||
|
int[] out = new int[1];
|
||||||
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||||
|
return out[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the booster model into thread-local rabit checkpoint.
|
||||||
|
* This is only used in distributed training.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
void saveRabitCheckpoint() throws XGBoostError {
|
||||||
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* transfer DMatrix array to handle array (used for native functions)
|
* transfer DMatrix array to handle array (used for native functions)
|
||||||
*
|
*
|
||||||
|
|||||||
@ -0,0 +1,93 @@
|
|||||||
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rabit global class for synchronization.
|
||||||
|
*/
|
||||||
|
public class Rabit {
|
||||||
|
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||||
|
//load native library
|
||||||
|
static {
|
||||||
|
try {
|
||||||
|
NativeLibLoader.initXgBoost();
|
||||||
|
} catch (IOException ex) {
|
||||||
|
logger.error("load native library failed.");
|
||||||
|
logger.error(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void checkCall(int ret) throws XGBoostError {
|
||||||
|
if (ret != 0) {
|
||||||
|
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize the rabit library on current working thread.
|
||||||
|
* @param envs The additional environment variables to pass to rabit.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static void init(Map<String, String> envs) throws XGBoostError {
|
||||||
|
String[] args = new String[envs.size()];
|
||||||
|
int idx = 0;
|
||||||
|
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||||
|
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||||
|
}
|
||||||
|
checkCall(XgboostJNI.RabitInit(args));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shutdown the rabit engine in current working thread, equals to finalize.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static void shutdown() throws XGBoostError {
|
||||||
|
checkCall(XgboostJNI.RabitFinalize());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Print the message on rabit tracker.
|
||||||
|
* @param msg
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static void trackerPrint(String msg) throws XGBoostError {
|
||||||
|
checkCall(XgboostJNI.RabitTrackerPrint(msg));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get version number of current stored model in the thread.
|
||||||
|
* which means how many calls to CheckPoint we made so far.
|
||||||
|
* @return version Number.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static int versionNumber() throws XGBoostError {
|
||||||
|
int[] out = new int[1];
|
||||||
|
checkCall(XgboostJNI.RabitVersionNumber(out));
|
||||||
|
return out[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get rank of current thread.
|
||||||
|
* @return the rank.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static int getRank() throws XGBoostError {
|
||||||
|
int[] out = new int[1];
|
||||||
|
checkCall(XgboostJNI.RabitGetRank(out));
|
||||||
|
return out[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get world size of current job.
|
||||||
|
* @return the worldsize
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public static int getWorldSize() throws XGBoostError {
|
||||||
|
int[] out = new int[1];
|
||||||
|
checkCall(XgboostJNI.RabitGetWorldSize(out));
|
||||||
|
return out[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -72,14 +72,20 @@ public class XGBoost {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//initialize booster
|
//initialize booster
|
||||||
Booster booster = new JavaBoosterImpl(params, allMats);
|
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
|
||||||
|
|
||||||
|
int version = booster.loadRabitCheckpoint();
|
||||||
|
|
||||||
//begin to train
|
//begin to train
|
||||||
for (int iter = 0; iter < round; iter++) {
|
for (int iter = version / 2; iter < round; iter++) {
|
||||||
if (obj != null) {
|
if (version % 2 == 0) {
|
||||||
booster.update(dtrain, obj);
|
if (obj != null) {
|
||||||
} else {
|
booster.update(dtrain, obj);
|
||||||
booster.update(dtrain, iter);
|
} else {
|
||||||
|
booster.update(dtrain, iter);
|
||||||
|
}
|
||||||
|
booster.saveRabitCheckpoint();
|
||||||
|
version += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//evaluation
|
//evaluation
|
||||||
@ -90,8 +96,12 @@ public class XGBoost {
|
|||||||
} else {
|
} else {
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||||
}
|
}
|
||||||
logger.info(evalInfo);
|
if (Rabit.getRank() == 0) {
|
||||||
|
Rabit.trackerPrint(evalInfo + '\n');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
booster.saveRabitCheckpoint();
|
||||||
|
version += 1;
|
||||||
}
|
}
|
||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* xgboost jni wrapper functions for xgboost_wrapper.h
|
* xgboost JNI functions
|
||||||
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||||
*
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
@ -80,4 +80,17 @@ class XgboostJNI {
|
|||||||
|
|
||||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
||||||
String[][] out_strings);
|
String[][] out_strings);
|
||||||
|
|
||||||
|
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||||
|
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||||
|
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||||
|
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||||
|
|
||||||
|
// rabit functions
|
||||||
|
public final static native int RabitInit(String[] args);
|
||||||
|
public final static native int RabitFinalize();
|
||||||
|
public final static native int RabitTrackerPrint(String msg);
|
||||||
|
public final static native int RabitGetRank(int[] out);
|
||||||
|
public final static native int RabitGetWorldSize(int[] out);
|
||||||
|
public final static native int RabitVersionNumber(int[] out);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,45 +1,50 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "xgboost/c_api.h"
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
#include "xgboost4j.h"
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <xgboost/c_api.h>
|
||||||
|
#include "./xgboost4j.h"
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
//helper functions
|
//helper functions
|
||||||
//set handle
|
//set handle
|
||||||
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
||||||
long out[1];
|
long out = (long) handle;
|
||||||
out[0] = (long) handle;
|
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
|
||||||
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
||||||
(JNIEnv *jenv, jclass jcls) {
|
(JNIEnv *jenv, jclass jcls) {
|
||||||
jstring jresult = 0 ;
|
jstring jresult = 0;
|
||||||
const char* result = XGBGetLastError();
|
const char* result = XGBGetLastError();
|
||||||
if (result) jresult = jenv->NewStringUTF(result);
|
if (result != NULL) {
|
||||||
return jresult;
|
jresult = jenv->NewStringUTF(result);
|
||||||
|
}
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||||
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||||
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
|
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
if (fname) {
|
||||||
setHandle(jenv, jout, result);
|
jenv->ReleaseStringUTFChars(jfname, fname);
|
||||||
return ret;
|
}
|
||||||
|
setHandle(jenv, jout, result);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -49,19 +54,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
||||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||||
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
//Release
|
//Release
|
||||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||||
jenv->ReleaseIntArrayElements(jindices, indices, 0);
|
jenv->ReleaseIntArrayElements(jindices, indices, 0);
|
||||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -71,21 +76,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
||||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||||
jfloat* data = jenv->GetFloatArrayElements(jdata, NULL);
|
jfloat* data = jenv->GetFloatArrayElements(jdata, NULL);
|
||||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||||
|
|
||||||
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||||
jenv->ReleaseIntArrayElements(jindices, indices, 0);
|
jenv->ReleaseIntArrayElements(jindices, indices, 0);
|
||||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -95,15 +100,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
|
||||||
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||||
bst_ulong nrow = (bst_ulong)jnrow;
|
bst_ulong nrow = (bst_ulong)jnrow;
|
||||||
bst_ulong ncol = (bst_ulong)jncol;
|
bst_ulong ncol = (bst_ulong)jncol;
|
||||||
int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
|
int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -113,18 +118,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
|
||||||
DMatrixHandle result;
|
DMatrixHandle result;
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
|
|
||||||
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
|
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
|
||||||
|
|
||||||
int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
|
int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
|
||||||
setHandle(jenv, jout, result);
|
setHandle(jenv, jout, result);
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
|
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -133,10 +138,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
|
|||||||
* Signature: (J)V
|
* Signature: (J)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
int ret = XGDMatrixFree(handle);
|
int ret = XGDMatrixFree(handle);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -146,11 +151,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||||
int ret = XGDMatrixSaveBinary(handle, fname, jsilent);
|
int ret = XGDMatrixSaveBinary(handle, fname, jsilent);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -160,16 +165,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
|
|
||||||
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
|
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||||
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
|
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
|
||||||
//release
|
//release
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
jenv->ReleaseFloatArrayElements(jarray, array, 0);
|
jenv->ReleaseFloatArrayElements(jarray, array, 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -179,16 +184,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||||
int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
|
int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
|
||||||
//release
|
//release
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
||||||
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -198,13 +203,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||||
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
|
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -214,19 +219,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
bst_ulong len;
|
bst_ulong len;
|
||||||
float *result;
|
float *result;
|
||||||
int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result);
|
int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result);
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
|
|
||||||
jsize jlen = (jsize) len;
|
jsize jlen = (jsize) len;
|
||||||
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
|
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -236,18 +241,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
bst_ulong len;
|
bst_ulong len;
|
||||||
unsigned int *result;
|
unsigned int *result;
|
||||||
int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result);
|
int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result);
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
|
|
||||||
jsize jlen = (jsize) len;
|
jsize jlen = (jsize) len;
|
||||||
jintArray jarray = jenv->NewIntArray(jlen);
|
jintArray jarray = jenv->NewIntArray(jlen);
|
||||||
jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result);
|
jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -257,11 +262,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
bst_ulong result[1];
|
bst_ulong result[1];
|
||||||
int ret = (jint) XGDMatrixNumRow(handle, result);
|
int ret = (jint) XGDMatrixNumRow(handle, result);
|
||||||
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -271,30 +276,29 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
||||||
DMatrixHandle* handles;
|
DMatrixHandle* handles = NULL;
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
jlong* cjhandles = 0;
|
jlong* cjhandles = 0;
|
||||||
BoosterHandle result;
|
BoosterHandle result;
|
||||||
|
|
||||||
if(jhandles) {
|
if (jhandles) {
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jhandles);
|
len = (bst_ulong)jenv->GetArrayLength(jhandles);
|
||||||
handles = new DMatrixHandle[len];
|
handles = new DMatrixHandle[len];
|
||||||
//put handle from jhandles to chandles
|
//put handle from jhandles to chandles
|
||||||
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
|
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
handles[i] = (DMatrixHandle) cjhandles[i];
|
handles[i] = (DMatrixHandle) cjhandles[i];
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
int ret = XGBoosterCreate(handles, len, &result);
|
|
||||||
//release
|
int ret = XGBoosterCreate(handles, len, &result);
|
||||||
if(jhandles) {
|
//release
|
||||||
delete[] handles;
|
if (jhandles) {
|
||||||
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
delete[] handles;
|
||||||
}
|
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
||||||
setHandle(jenv, jout, result);
|
}
|
||||||
|
setHandle(jenv, jout, result);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -316,14 +320,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
const char* name = jenv->GetStringUTFChars(jname, 0);
|
const char* name = jenv->GetStringUTFChars(jname, 0);
|
||||||
const char* value = jenv->GetStringUTFChars(jvalue, 0);
|
const char* value = jenv->GetStringUTFChars(jvalue, 0);
|
||||||
int ret = XGBoosterSetParam(handle, name, value);
|
int ret = XGBoosterSetParam(handle, name, value);
|
||||||
//release
|
//release
|
||||||
if (name) jenv->ReleaseStringUTFChars(jname, name);
|
if (name) jenv->ReleaseStringUTFChars(jname, name);
|
||||||
if (value) jenv->ReleaseStringUTFChars(jvalue, value);
|
if (value) jenv->ReleaseStringUTFChars(jvalue, value);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -333,9 +337,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||||
return XGBoosterUpdateOneIter(handle, jiter, dtrain);
|
return XGBoosterUpdateOneIter(handle, jiter, dtrain);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -345,16 +349,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||||
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
||||||
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
||||||
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
||||||
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -364,45 +368,45 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle* dmats = 0;
|
DMatrixHandle* dmats = 0;
|
||||||
char **evnames = 0;
|
char **evnames = 0;
|
||||||
char *result = 0;
|
char *result = 0;
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
|
||||||
if(len > 0) {
|
if(len > 0) {
|
||||||
dmats = new DMatrixHandle[len];
|
dmats = new DMatrixHandle[len];
|
||||||
evnames = new char*[len];
|
evnames = new char*[len];
|
||||||
}
|
}
|
||||||
//put handle from jhandles to chandles
|
//put handle from jhandles to chandles
|
||||||
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
|
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
|
||||||
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
|
dmats[i] = (DMatrixHandle) cjdmats[i];
|
||||||
|
}
|
||||||
|
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
|
||||||
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
|
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
||||||
|
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
|
||||||
|
evnames[i] = new char[jenv->GetStringLength(jevname)];
|
||||||
|
strcpy(evnames[i], cevname);
|
||||||
|
jenv->ReleaseStringUTFChars(jevname, cevname);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
|
||||||
|
if(len > 0) {
|
||||||
|
delete[] dmats;
|
||||||
|
//release string chars
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
dmats[i] = (DMatrixHandle) cjdmats[i];
|
delete[] evnames[i];
|
||||||
}
|
}
|
||||||
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
|
delete[] evnames;
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
||||||
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
}
|
||||||
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
|
|
||||||
evnames[i] = new char[jenv->GetStringLength(jevname)];
|
jstring jinfo = 0;
|
||||||
strcpy(evnames[i], cevname);
|
if (result) jinfo = jenv->NewStringUTF((const char *) result);
|
||||||
jenv->ReleaseStringUTFChars(jevname, cevname);
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
}
|
|
||||||
|
return ret;
|
||||||
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
|
|
||||||
if(len > 0) {
|
|
||||||
delete[] dmats;
|
|
||||||
//release string chars
|
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
|
||||||
delete[] evnames[i];
|
|
||||||
}
|
|
||||||
delete[] evnames;
|
|
||||||
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
jstring jinfo = 0;
|
|
||||||
if (result) jinfo = jenv->NewStringUTF((const char *) result);
|
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -412,17 +416,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle dmat = (DMatrixHandle) jdmat;
|
DMatrixHandle dmat = (DMatrixHandle) jdmat;
|
||||||
bst_ulong len;
|
bst_ulong len;
|
||||||
float *result;
|
float *result;
|
||||||
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
|
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
|
||||||
|
|
||||||
jsize jlen = (jsize) len;
|
jsize jlen = (jsize) len;
|
||||||
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -432,12 +436,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||||
|
|
||||||
int ret = XGBoosterLoadModel(handle, fname);
|
int ret = XGBoosterLoadModel(handle, fname);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname,fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname,fname);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -447,13 +451,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||||
|
|
||||||
int ret = XGBoosterSaveModel(handle, fname);
|
int ret = XGBoosterSaveModel(handle, fname);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -463,9 +467,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
void *buf = (void*) jbuf;
|
void *buf = (void*) jbuf;
|
||||||
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen);
|
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -475,16 +479,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromB
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
char *result;
|
char *result;
|
||||||
|
|
||||||
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
|
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
|
||||||
if (result){
|
if (result) {
|
||||||
jstring jinfo = jenv->NewStringUTF((const char *) result);
|
jstring jinfo = jenv->NewStringUTF((const char *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -494,20 +498,129 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
|||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
char **result;
|
char **result;
|
||||||
|
|
||||||
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result);
|
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result);
|
||||||
|
|
||||||
jsize jlen = (jsize) len;
|
jsize jlen = (jsize) len;
|
||||||
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
||||||
for(int i=0 ; i<jlen; i++) {
|
for(int i=0 ; i<jlen; i++) {
|
||||||
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
|
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
|
||||||
}
|
}
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
||||||
|
|
||||||
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
|
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
|
||||||
return ret;
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBoosterLoadRabitCheckpoint
|
||||||
|
* Signature: (J[I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
|
||||||
|
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
||||||
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
|
int version;
|
||||||
|
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
||||||
|
jenv->SetIntArrayRegion(jout, 0, 1, &version);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBoosterSaveRabitCheckpoint
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
|
||||||
|
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||||
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
|
return XGBoosterSaveRabitCheckpoint(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitInit
|
||||||
|
* Signature: ([Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||||
|
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||||
|
std::vector<std::string> args;
|
||||||
|
std::vector<char*> argv;
|
||||||
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
|
||||||
|
for (bst_ulong i = 0; i < len; ++i) {
|
||||||
|
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
|
||||||
|
std::string s(jenv->GetStringUTFChars(arg, 0),
|
||||||
|
jenv->GetStringLength(arg));
|
||||||
|
if (s.length() != 0) args.push_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
|
argv.push_back(&args[i][0]);
|
||||||
|
}
|
||||||
|
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitFinalize
|
||||||
|
* Signature: ()I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
|
||||||
|
(JNIEnv *jenv, jclass jcls) {
|
||||||
|
RabitFinalize();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitTrackerPrint
|
||||||
|
* Signature: (Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||||
|
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
|
||||||
|
std::string str(jenv->GetStringUTFChars(jmsg, 0),
|
||||||
|
jenv->GetStringLength(jmsg));
|
||||||
|
RabitTrackerPrint(str.c_str());
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitGetRank
|
||||||
|
* Signature: ([I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||||
|
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||||
|
int rank = RabitGetRank();
|
||||||
|
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitGetWorldSize
|
||||||
|
* Signature: ([I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||||
|
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||||
|
int out = RabitGetWorldSize();
|
||||||
|
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitVersionNumber
|
||||||
|
* Signature: ([I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
|
||||||
|
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||||
|
int out = RabitVersionNumber();
|
||||||
|
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||||
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
|
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBoosterGetAttr
|
||||||
|
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBoosterSetAttr
|
||||||
|
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBoosterLoadRabitCheckpoint
|
||||||
|
* Signature: (J[I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
|
||||||
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: XGBoosterSaveRabitCheckpoint
|
||||||
|
* Signature: (J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
|
||||||
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitInit
|
||||||
|
* Signature: ([Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||||
|
(JNIEnv *, jclass, jobjectArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitFinalize
|
||||||
|
* Signature: ()I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
|
||||||
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitTrackerPrint
|
||||||
|
* Signature: (Ljava/lang/String;)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||||
|
(JNIEnv *, jclass, jstring);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitGetRank
|
||||||
|
* Signature: ([I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||||
|
(JNIEnv *, jclass, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitGetWorldSize
|
||||||
|
* Signature: ([I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||||
|
(JNIEnv *, jclass, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||||
|
* Method: RabitVersionNumber
|
||||||
|
* Signature: ([I)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
|
||||||
|
(JNIEnv *, jclass, jintArray);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -8,11 +8,18 @@ import numpy as np
|
|||||||
from .core import _LIB, c_str, STRING_TYPES
|
from .core import _LIB, c_str, STRING_TYPES
|
||||||
|
|
||||||
def _init_rabit():
|
def _init_rabit():
|
||||||
"""Initialize the rabit library."""
|
"""internal libary initializer."""
|
||||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||||
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
||||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||||
_LIB.RabitInit(0, None)
|
|
||||||
|
def init(args=None):
|
||||||
|
"""Initialize the rabit libary with arguments"""
|
||||||
|
if args is None:
|
||||||
|
args = []
|
||||||
|
arr = (ctypes.c_char_p * len(args))()
|
||||||
|
arr[:] = args
|
||||||
|
_LIB.RabitInit(len(arr), arr)
|
||||||
|
|
||||||
|
|
||||||
def finalize():
|
def finalize():
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
|
||||||
@ -4,6 +4,9 @@ import scipy.sparse
|
|||||||
import pickle
|
import pickle
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
|
||||||
|
# always call this before using distributed module
|
||||||
|
xgb.rabit.init()
|
||||||
|
|
||||||
# Load file, file will be automatically sharded in distributed mode.
|
# Load file, file will be automatically sharded in distributed mode.
|
||||||
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||||
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user