[jvm-packages] Add BigDenseMatrix (#4383)
* Add BigDenseMatrix * ability to create DMatrix with bigger than Integer.MAX_VALUE size arrays * uses sun.misc.Unsafe * make DMatrix test work from a jar as well
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
57106a3459
commit
22209b7b95
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.java;
|
||||
import java.util.Iterator;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
|
||||
|
||||
/**
|
||||
* DMatrix for xgboost.
|
||||
@@ -130,6 +131,16 @@ public class DMatrix {
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from a BigDenseMatrix
|
||||
*
|
||||
* @param matrix instance of BigDenseMatrix
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public DMatrix(BigDenseMatrix matrix) throws XGBoostError {
|
||||
this(matrix, 0.0f);
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param data data values
|
||||
@@ -143,6 +154,18 @@ public class DMatrix {
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
* @param matrix instance of BigDenseMatrix
|
||||
* @param missing the specified value to represent the missing value
|
||||
*/
|
||||
public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow,
|
||||
matrix.ncol, missing, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
*/
|
||||
|
||||
@@ -65,6 +65,9 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
|
||||
float missing, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromMatRef(long dataRef, int nrow, int ncol,
|
||||
float missing, long[] out);
|
||||
|
||||
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
|
||||
|
||||
public final static native int XGDMatrixFree(long handle);
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.util;
|
||||
|
||||
/**
|
||||
* Off-heap implementation of a Dense Matrix, matrix size is only limited by the
|
||||
* amount of the available memory and the matrix dimension cannot exceed
|
||||
* Integer.MAX_VALUE (this is consistent with XGBoost API restrictions on maximum
|
||||
* length of a response).
|
||||
*/
|
||||
public final class BigDenseMatrix {
|
||||
|
||||
private static final int FLOAT_BYTE_SIZE = 4;
|
||||
public static final long MAX_MATRIX_SIZE = Long.MAX_VALUE / FLOAT_BYTE_SIZE;
|
||||
|
||||
public final int nrow;
|
||||
public final int ncol;
|
||||
public final long address;
|
||||
|
||||
public static void setDirect(long valAddress, float val) {
|
||||
UtilUnsafe.UNSAFE.putFloat(valAddress, val);
|
||||
}
|
||||
|
||||
public static float getDirect(long valAddress) {
|
||||
return UtilUnsafe.UNSAFE.getFloat(valAddress);
|
||||
}
|
||||
|
||||
public BigDenseMatrix(int nrow, int ncol) {
|
||||
final long size = (long) nrow * ncol;
|
||||
if (size > MAX_MATRIX_SIZE) {
|
||||
throw new IllegalArgumentException("Matrix too large; matrix size cannot exceed " +
|
||||
MAX_MATRIX_SIZE);
|
||||
}
|
||||
this.nrow = nrow;
|
||||
this.ncol = ncol;
|
||||
this.address = UtilUnsafe.UNSAFE.allocateMemory(size * FLOAT_BYTE_SIZE);
|
||||
}
|
||||
|
||||
public final void set(long idx, float val) {
|
||||
setDirect(address + idx * FLOAT_BYTE_SIZE, val);
|
||||
}
|
||||
|
||||
public final void set(int i, int j, float val) {
|
||||
set(index(i, j), val);
|
||||
}
|
||||
|
||||
public final float get(long idx) {
|
||||
return getDirect(address + idx * FLOAT_BYTE_SIZE);
|
||||
}
|
||||
|
||||
public final float get(int i, int j) {
|
||||
return get(index(i, j));
|
||||
}
|
||||
|
||||
public final void dispose() {
|
||||
UtilUnsafe.UNSAFE.freeMemory(address);
|
||||
}
|
||||
|
||||
private long index(int i, int j) {
|
||||
return (long) i * ncol + j;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.util;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import sun.misc.Unsafe;
|
||||
|
||||
/**
|
||||
* Simple class to obtain access to the {@link Unsafe} object. Use responsibly :)
|
||||
*/
|
||||
public final class UtilUnsafe {
|
||||
|
||||
static Unsafe UNSAFE = getUnsafe();
|
||||
|
||||
private UtilUnsafe() {
|
||||
} // dummy private constructor
|
||||
|
||||
private static Unsafe getUnsafe() {
|
||||
// Not on bootclasspath
|
||||
if (UtilUnsafe.class.getClassLoader() == null) {
|
||||
return Unsafe.getUnsafe();
|
||||
}
|
||||
try {
|
||||
final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
|
||||
fld.setAccessible(true);
|
||||
return (Unsafe) fld.get(UtilUnsafe.class);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user