[jvm-packages]add feature size for LabelPoint and DataBatch (#5303)
* fix type error * Validate number of features. * resolve comments * add feature size for LabelPoint and DataBatch * pass the feature size to native * move feature size validating tests into a separate suite * resolve comments Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
@@ -49,7 +49,6 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
*/
|
||||
Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
|
||||
init(cacheMats);
|
||||
setParam("validate_features", "0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
|
||||
@@ -27,14 +27,17 @@ class DataBatch {
|
||||
final int[] featureIndex;
|
||||
/** value of each non-missing entry in the sparse matrix */
|
||||
final float[] featureValue ;
|
||||
/** feature columns */
|
||||
final int featureCols;
|
||||
|
||||
DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
|
||||
float[] featureValue) {
|
||||
float[] featureValue, int featureCols) {
|
||||
this.rowOffset = rowOffset;
|
||||
this.weight = weight;
|
||||
this.label = label;
|
||||
this.featureIndex = featureIndex;
|
||||
this.featureValue = featureValue;
|
||||
this.featureCols = featureCols;
|
||||
}
|
||||
|
||||
static class BatchIterator implements Iterator<DataBatch> {
|
||||
@@ -56,9 +59,15 @@ class DataBatch {
|
||||
try {
|
||||
int numRows = 0;
|
||||
int numElem = 0;
|
||||
int numCol = -1;
|
||||
List<LabeledPoint> batch = new ArrayList<>(batchSize);
|
||||
while (base.hasNext() && batch.size() < batchSize) {
|
||||
LabeledPoint labeledPoint = base.next();
|
||||
if (numCol == -1) {
|
||||
numCol = labeledPoint.size();
|
||||
} else if (numCol != labeledPoint.size()) {
|
||||
throw new RuntimeException("Feature size is not the same");
|
||||
}
|
||||
batch.add(labeledPoint);
|
||||
numElem += labeledPoint.values().length;
|
||||
numRows++;
|
||||
@@ -91,7 +100,7 @@ class DataBatch {
|
||||
}
|
||||
|
||||
rowOffset[batch.size()] = offset;
|
||||
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
|
||||
return new DataBatch(rowOffset, weight, label, featureIndex, featureValue, numCol);
|
||||
} catch (RuntimeException runtimeError) {
|
||||
logger.error(runtimeError);
|
||||
return null;
|
||||
|
||||
@@ -20,6 +20,7 @@ package ml.dmlc.xgboost4j
|
||||
* Labeled training data point.
|
||||
*
|
||||
* @param label Label of this point.
|
||||
* @param size Feature dimensionality
|
||||
* @param indices Feature indices of this point or `null` if the data is dense.
|
||||
* @param values Feature values of this point.
|
||||
* @param weight Weight of this point.
|
||||
@@ -28,6 +29,7 @@ package ml.dmlc.xgboost4j
|
||||
*/
|
||||
case class LabeledPoint(
|
||||
label: Float,
|
||||
size: Int,
|
||||
indices: Array[Int],
|
||||
values: Array[Float],
|
||||
weight: Float = 1f,
|
||||
@@ -36,8 +38,11 @@ case class LabeledPoint(
|
||||
require(indices == null || indices.length == values.length,
|
||||
"indices and values must have the same number of elements")
|
||||
|
||||
def this(label: Float, indices: Array[Int], values: Array[Float]) = {
|
||||
require(indices == null || size >= indices.length,
|
||||
"feature dimensionality must be greater equal than size of indices")
|
||||
|
||||
def this(label: Float, size: Int, indices: Array[Int], values: Array[Float]) = {
|
||||
// [[weight]] default duplicated to disambiguate the constructor call.
|
||||
this(label, indices, values, 1.0f)
|
||||
this(label, size, indices, values, 1.0f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,9 +91,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
|
||||
jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
|
||||
batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
|
||||
jint jcols = jenv->GetIntField(
|
||||
batch, jenv->GetFieldID(batchClass, "featureCols", "I"));
|
||||
XGBoostBatchCSR cbatch;
|
||||
cbatch.size = jenv->GetArrayLength(joffset) - 1;
|
||||
cbatch.columns = std::numeric_limits<size_t>::max();
|
||||
cbatch.columns = jcols;
|
||||
cbatch.offset = reinterpret_cast<jlong *>(
|
||||
jenv->GetLongArrayElements(joffset, 0));
|
||||
if (jlabel != nullptr) {
|
||||
|
||||
@@ -45,7 +45,7 @@ public class DMatrixTest {
|
||||
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||
for (int i = 0; i < nrep; ++i) {
|
||||
LabeledPoint p = new LabeledPoint(
|
||||
0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
|
||||
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
|
||||
blist.add(p);
|
||||
labelall.add(p.label());
|
||||
}
|
||||
@@ -57,6 +57,33 @@ public class DMatrixTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDataIteratorWithDiffFeatureSize() throws XGBoostError {
|
||||
//create DMatrix from DataIterator
|
||||
|
||||
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||
int nrep = 3000;
|
||||
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||
int featureSize = 4;
|
||||
for (int i = 0; i < nrep; ++i) {
|
||||
// set some rows with wrong feature size
|
||||
if (i % 10 == 1) {
|
||||
featureSize = 5;
|
||||
}
|
||||
LabeledPoint p = new LabeledPoint(
|
||||
0.1f + i, featureSize, new int[]{0, 2, 3}, new float[]{3, 4, 5});
|
||||
blist.add(p);
|
||||
labelall.add(p.label());
|
||||
}
|
||||
boolean success = true;
|
||||
try {
|
||||
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||
} catch (XGBoostError e) {
|
||||
success = false;
|
||||
}
|
||||
TestCase.assertTrue(success == false);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromFile() throws XGBoostError {
|
||||
//create DMatrix from file
|
||||
|
||||
Reference in New Issue
Block a user