[jvm-packages] Add BigDenseMatrix (#4383)

* Add BigDenseMatrix

* ability to create DMatrix with bigger than Integer.MAX_VALUE size arrays
* uses sun.misc.Unsafe

* make DMatrix test work from a jar as well
This commit is contained in:
Honza Sterba
2019-09-19 05:46:14 +02:00
committed by Philip Hyunsu Cho
parent 57106a3459
commit 22209b7b95
8 changed files with 300 additions and 1 deletions

View File

@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
/**
* DMatrix for xgboost.
@@ -130,6 +131,16 @@ public class DMatrix {
handle = out[0];
}
/**
* create DMatrix from a BigDenseMatrix
*
* @param matrix instance of BigDenseMatrix
* @throws XGBoostError native error
*/
public DMatrix(BigDenseMatrix matrix) throws XGBoostError {
this(matrix, 0.0f);
}
/**
* create DMatrix from dense matrix
* @param data data values
@@ -143,6 +154,18 @@ public class DMatrix {
handle = out[0];
}
/**
* create DMatrix from dense matrix
* @param matrix instance of BigDenseMatrix
* @param missing the specified value to represent the missing value
*/
public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError {
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow,
matrix.ncol, missing, out));
handle = out[0];
}
/**
* used for DMatrix slice
*/

View File

@@ -65,6 +65,9 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol,
float missing, long[] out);
public final static native int XGDMatrixCreateFromMatRef(long dataRef, int nrow, int ncol,
float missing, long[] out);
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
public final static native int XGDMatrixFree(long handle);

View File

@@ -0,0 +1,76 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.util;
/**
* Off-heap implementation of a Dense Matrix, matrix size is only limited by the
* amount of the available memory and the matrix dimension cannot exceed
* Integer.MAX_VALUE (this is consistent with XGBoost API restrictions on maximum
* length of a response).
*/
public final class BigDenseMatrix {
private static final int FLOAT_BYTE_SIZE = 4;
public static final long MAX_MATRIX_SIZE = Long.MAX_VALUE / FLOAT_BYTE_SIZE;
public final int nrow;
public final int ncol;
public final long address;
public static void setDirect(long valAddress, float val) {
UtilUnsafe.UNSAFE.putFloat(valAddress, val);
}
public static float getDirect(long valAddress) {
return UtilUnsafe.UNSAFE.getFloat(valAddress);
}
public BigDenseMatrix(int nrow, int ncol) {
final long size = (long) nrow * ncol;
if (size > MAX_MATRIX_SIZE) {
throw new IllegalArgumentException("Matrix too large; matrix size cannot exceed " +
MAX_MATRIX_SIZE);
}
this.nrow = nrow;
this.ncol = ncol;
this.address = UtilUnsafe.UNSAFE.allocateMemory(size * FLOAT_BYTE_SIZE);
}
public final void set(long idx, float val) {
setDirect(address + idx * FLOAT_BYTE_SIZE, val);
}
public final void set(int i, int j, float val) {
set(index(i, j), val);
}
public final float get(long idx) {
return getDirect(address + idx * FLOAT_BYTE_SIZE);
}
public final float get(int i, int j) {
return get(index(i, j));
}
public final void dispose() {
UtilUnsafe.UNSAFE.freeMemory(address);
}
private long index(int i, int j) {
return (long) i * ncol + j;
}
}

View File

@@ -0,0 +1,46 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.util;
import java.lang.reflect.Field;
import sun.misc.Unsafe;
/**
* Simple class to obtain access to the {@link Unsafe} object. Use responsibly :)
*/
public final class UtilUnsafe {
static Unsafe UNSAFE = getUnsafe();
private UtilUnsafe() {
} // dummy private constructor
private static Unsafe getUnsafe() {
// Not on bootclasspath
if (UtilUnsafe.class.getClassLoader() == null) {
return Unsafe.getUnsafe();
}
try {
final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
fld.setAccessible(true);
return (Unsafe) fld.get(UtilUnsafe.class);
} catch (Exception e) {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}