[jvm-packages] add feature size for LabeledPoint and DataBatch (#5303)

* fix type error

* Validate number of features.

* resolve comments

* add feature size for LabeledPoint and DataBatch

* pass the feature size to native

* move feature size validating tests into a separate suite

* resolve comments

Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2020-04-08 07:49:52 +08:00
committed by GitHub
parent 8bc595ea1e
commit ad826e913f
17 changed files with 193 additions and 75 deletions

View File

@@ -49,7 +49,6 @@ public class Booster implements Serializable, KryoSerializable {
*/
Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
init(cacheMats);
setParam("validate_features", "0");
setParams(params);
}

View File

@@ -27,14 +27,17 @@ class DataBatch {
final int[] featureIndex;
/** value of each non-missing entry in the sparse matrix */
final float[] featureValue ;
/** feature columns */
final int featureCols;
DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
float[] featureValue) {
float[] featureValue, int featureCols) {
this.rowOffset = rowOffset;
this.weight = weight;
this.label = label;
this.featureIndex = featureIndex;
this.featureValue = featureValue;
this.featureCols = featureCols;
}
static class BatchIterator implements Iterator<DataBatch> {
@@ -56,9 +59,15 @@ class DataBatch {
try {
int numRows = 0;
int numElem = 0;
int numCol = -1;
List<LabeledPoint> batch = new ArrayList<>(batchSize);
while (base.hasNext() && batch.size() < batchSize) {
LabeledPoint labeledPoint = base.next();
if (numCol == -1) {
numCol = labeledPoint.size();
} else if (numCol != labeledPoint.size()) {
throw new RuntimeException("Feature size is not the same");
}
batch.add(labeledPoint);
numElem += labeledPoint.values().length;
numRows++;
@@ -91,7 +100,7 @@ class DataBatch {
}
rowOffset[batch.size()] = offset;
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue, numCol);
} catch (RuntimeException runtimeError) {
logger.error(runtimeError);
return null;

View File

@@ -20,6 +20,7 @@ package ml.dmlc.xgboost4j
* Labeled training data point.
*
* @param label Label of this point.
* @param size Feature dimensionality
* @param indices Feature indices of this point or `null` if the data is dense.
* @param values Feature values of this point.
* @param weight Weight of this point.
@@ -28,6 +29,7 @@ package ml.dmlc.xgboost4j
*/
case class LabeledPoint(
label: Float,
size: Int,
indices: Array[Int],
values: Array[Float],
weight: Float = 1f,
@@ -36,8 +38,11 @@ case class LabeledPoint(
require(indices == null || indices.length == values.length,
"indices and values must have the same number of elements")
def this(label: Float, indices: Array[Int], values: Array[Float]) = {
require(indices == null || size >= indices.length,
"feature dimensionality must be greater equal than size of indices")
def this(label: Float, size: Int, indices: Array[Int], values: Array[Float]) = {
// [[weight]] default duplicated to disambiguate the constructor call.
this(label, indices, values, 1.0f)
this(label, size, indices, values, 1.0f)
}
}

View File

@@ -91,9 +91,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
jint jcols = jenv->GetIntField(
batch, jenv->GetFieldID(batchClass, "featureCols", "I"));
XGBoostBatchCSR cbatch;
cbatch.size = jenv->GetArrayLength(joffset) - 1;
cbatch.columns = std::numeric_limits<size_t>::max();
cbatch.columns = jcols;
cbatch.offset = reinterpret_cast<jlong *>(
jenv->GetLongArrayElements(joffset, 0));
if (jlabel != nullptr) {

View File

@@ -45,7 +45,7 @@ public class DMatrixTest {
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
for (int i = 0; i < nrep; ++i) {
LabeledPoint p = new LabeledPoint(
0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label());
}
@@ -57,6 +57,33 @@ public class DMatrixTest {
}
}
@Test
public void testCreateFromDataIteratorWithDiffFeatureSize() throws XGBoostError {
//create DMatrix from DataIterator
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
int nrep = 3000;
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
int featureSize = 4;
for (int i = 0; i < nrep; ++i) {
// set some rows with wrong feature size
if (i % 10 == 1) {
featureSize = 5;
}
LabeledPoint p = new LabeledPoint(
0.1f + i, featureSize, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label());
}
boolean success = true;
try {
DMatrix dmat = new DMatrix(blist.iterator(), null);
} catch (XGBoostError e) {
success = false;
}
TestCase.assertTrue(success == false);
}
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file