[refactor] move java package to namespace java
This commit is contained in:
@@ -65,7 +65,7 @@ public class DMatrix {
|
||||
logger.info(e.toString());
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ public class DMatrix {
|
||||
throw new NullPointerException("dataPath: null");
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@@ -95,9 +95,9 @@ public class DMatrix {
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
} else if (st == SparseType.CSC) {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
} else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
@@ -114,7 +114,7 @@ public class DMatrix {
|
||||
*/
|
||||
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -143,7 +143,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -154,7 +154,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -176,18 +176,18 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
||||
}
|
||||
|
||||
private float[] getFloatInfo(String field) throws XGBoostError {
|
||||
float[][] infos = new float[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
private int[] getIntInfo(String field) throws XGBoostError {
|
||||
int[][] infos = new int[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
@@ -230,7 +230,7 @@ public class DMatrix {
|
||||
*/
|
||||
public DMatrix slice(int[] rowIndex) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
|
||||
long sHandle = out[0];
|
||||
DMatrix sMatrix = new DMatrix(sHandle);
|
||||
return sMatrix;
|
||||
@@ -244,7 +244,7 @@ public class DMatrix {
|
||||
*/
|
||||
public long rowNum() throws XGBoostError {
|
||||
long[] rowNum = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle, rowNum));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum));
|
||||
return rowNum[0];
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ public class DMatrix {
|
||||
* save DMatrix to filePath
|
||||
*/
|
||||
public void saveBinary(String filePath) {
|
||||
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
XGBoostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -285,7 +285,7 @@ public class DMatrix {
|
||||
|
||||
public synchronized void dispose() {
|
||||
if (handle != 0) {
|
||||
XgboostJNI.XGDMatrixFree(handle);
|
||||
XGBoostJNI.XGDMatrixFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ class JNIErrorHandle {
|
||||
*/
|
||||
static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class JavaBoosterImpl implements Booster {
|
||||
handles = dmatrixsToHandles(dMatrixs);
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
||||
|
||||
handle = out[0];
|
||||
}
|
||||
@@ -94,7 +94,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public final void setParam(String key, String value) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -120,7 +120,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -149,7 +149,7 @@ class JavaBoosterImpl implements Booster {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
hess));
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ class JavaBoosterImpl implements Booster {
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||
String[] evalInfo = new String[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
@@ -211,7 +211,7 @@ class JavaBoosterImpl implements Booster {
|
||||
optionMask = 2;
|
||||
}
|
||||
float[][] rawPredicts = new float[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
treeLimit, rawPredicts));
|
||||
int row = (int) data.rowNum();
|
||||
int col = rawPredicts[0].length / row;
|
||||
@@ -284,11 +284,11 @@ class JavaBoosterImpl implements Booster {
|
||||
* @param modelPath model path
|
||||
*/
|
||||
public void saveModel(String modelPath) throws XGBoostError{
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||
}
|
||||
|
||||
private void loadModel(String modelPath) {
|
||||
XgboostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -304,7 +304,7 @@ class JavaBoosterImpl implements Booster {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
@@ -322,7 +322,7 @@ class JavaBoosterImpl implements Booster {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
@@ -450,7 +450,7 @@ class JavaBoosterImpl implements Booster {
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
byte[][] bytes = new byte[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
@@ -463,7 +463,7 @@ class JavaBoosterImpl implements Booster {
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
@@ -473,7 +473,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -504,7 +504,7 @@ class JavaBoosterImpl implements Booster {
|
||||
try {
|
||||
this.init(null);
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
@@ -518,7 +518,7 @@ class JavaBoosterImpl implements Booster {
|
||||
|
||||
public synchronized void dispose() {
|
||||
if (handle != 0L) {
|
||||
XgboostJNI.XGBoosterFree(handle);
|
||||
XGBoostJNI.XGBoosterFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ public class Rabit {
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ public class Rabit {
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||
}
|
||||
checkCall(XgboostJNI.RabitInit(args));
|
||||
checkCall(XGBoostJNI.RabitInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -46,7 +46,7 @@ public class Rabit {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void shutdown() throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitFinalize());
|
||||
checkCall(XGBoostJNI.RabitFinalize());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -55,7 +55,7 @@ public class Rabit {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void trackerPrint(String msg) throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitTrackerPrint(msg));
|
||||
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -66,7 +66,7 @@ public class Rabit {
|
||||
*/
|
||||
public static int versionNumber() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitVersionNumber(out));
|
||||
checkCall(XGBoostJNI.RabitVersionNumber(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ public class Rabit {
|
||||
*/
|
||||
public static int getRank() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetRank(out));
|
||||
checkCall(XGBoostJNI.RabitGetRank(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ public class Rabit {
|
||||
*/
|
||||
public static int getWorldSize() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetWorldSize(out));
|
||||
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ package ml.dmlc.xgboost4j.java;
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
class XgboostJNI {
|
||||
class XGBoostJNI {
|
||||
public final static native String XGBGetLastError();
|
||||
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
@@ -21,7 +21,6 @@ import _root_.scala.collection.JavaConverters._
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
||||
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
*
|
||||
|
||||
@@ -16,10 +16,11 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java
|
||||
import ml.dmlc.xgboost4j.java.IObjective
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.java.IObjective
|
||||
|
||||
trait ObjectiveTrait extends IObjective {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
@@ -30,7 +31,7 @@ trait ObjectiveTrait extends IObjective {
|
||||
*/
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
|
||||
|
||||
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: java.DMatrix):
|
||||
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
|
||||
java.util.List[Array[Float]] = {
|
||||
getGradient(predicts, new DMatrix(dtrain)).asJava
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
object XGBoost {
|
||||
|
||||
Reference in New Issue
Block a user