[jvm-packages] support missing value when constructing dmatrix with iterator (#10628)

This commit is contained in:
Bobby Wang 2024-07-23 23:25:07 +08:00 committed by GitHub
parent b3ed81877a
commit 7949a8d5f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 300 additions and 162 deletions

View File

@ -402,6 +402,7 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
* \param data_handle The handle to the data. * \param data_handle The handle to the data.
* \param callback The callback to get the data. * \param callback The callback to get the data.
* \param cache_info Additional information about cache file, can be null. * \param cache_info Additional information about cache file, can be null.
* \param missing The value that is treated as a missing entry.
* \param out The created DMatrix * \param out The created DMatrix
* \return 0 when success, -1 when failure happens. * \return 0 when success, -1 when failure happens.
*/ */
@ -409,6 +410,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
DataIterHandle data_handle, DataIterHandle data_handle,
XGBCallbackDataIterNext* callback, XGBCallbackDataIterNext* callback,
const char* cache_info, const char* cache_info,
float missing,
DMatrixHandle *out); DMatrixHandle *out);
/** /**

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014-2023 by Contributors Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -28,14 +28,6 @@ import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
public class DMatrix { public class DMatrix {
protected long handle = 0; protected long handle = 0;
/**
* sparse matrix type (CSR or CSC)
*/
public static enum SparseType {
CSR,
CSC;
}
/** /**
* Create DMatrix from iterator. * Create DMatrix from iterator.
* *
@ -44,6 +36,20 @@ public class DMatrix {
* @throws XGBoostError * @throws XGBoostError
*/ */
public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError { public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
this(iter, cacheInfo, Float.NaN);
}
/**
* Create DMatrix from iterator.
*
* @param iter The data iterator of mini batch to provide the data.
* @param cacheInfo Cache path information, used for external memory setting, can be null.
* @param missing the missing value
* @throws XGBoostError
*/
public DMatrix(Iterator<LabeledPoint> iter,
String cacheInfo,
float missing) throws XGBoostError {
if (iter == null) { if (iter == null) {
throw new NullPointerException("iter: null"); throw new NullPointerException("iter: null");
} }
@ -51,7 +57,8 @@ public class DMatrix {
int batchSize = 32 << 10; int batchSize = 32 << 10;
Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize); Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize);
long[] out = new long[1]; long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(
batchIter, cacheInfo, missing, out));
handle = out[0]; handle = out[0];
} }
@ -72,6 +79,7 @@ public class DMatrix {
/** /**
* Create DMatrix from Sparse matrix in CSR/CSC format. * Create DMatrix from Sparse matrix in CSR/CSC format.
*
* @param headers The row index of the matrix. * @param headers The row index of the matrix.
* @param indices The indices of presenting entries. * @param indices The indices of presenting entries.
* @param data The data content. * @param data The data content.
@ -86,6 +94,7 @@ public class DMatrix {
/** /**
* Create DMatrix from Sparse matrix in CSR/CSC format. * Create DMatrix from Sparse matrix in CSR/CSC format.
*
* @param headers The row index of the matrix. * @param headers The row index of the matrix.
* @param indices The indices of presenting entries. * @param indices The indices of presenting entries.
* @param data The data content. * @param data The data content.
@ -121,7 +130,6 @@ public class DMatrix {
* @param nrow number of rows * @param nrow number of rows
* @param ncol number of columns * @param ncol number of columns
* @throws XGBoostError native error * @throws XGBoostError native error
*
* @deprecated Please specify the missing value explicitly using * @deprecated Please specify the missing value explicitly using
* {@link #DMatrix(float[], int, int, float)} * {@link #DMatrix(float[], int, int, float)}
*/ */
@ -144,6 +152,7 @@ public class DMatrix {
/** /**
* create DMatrix from dense matrix * create DMatrix from dense matrix
*
* @param data data values * @param data data values
* @param nrow number of rows * @param nrow number of rows
* @param ncol number of columns * @param ncol number of columns
@ -157,6 +166,7 @@ public class DMatrix {
/** /**
* create DMatrix from dense matrix * create DMatrix from dense matrix
*
* @param matrix instance of BigDenseMatrix * @param matrix instance of BigDenseMatrix
* @param missing the specified value to represent the missing value * @param missing the specified value to represent the missing value
*/ */
@ -176,7 +186,8 @@ public class DMatrix {
/** /**
* Create the normal DMatrix from column array interface * Create the normal DMatrix from column array interface
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface *
* @param columnBatch the XGBoost ColumnBatch to provide the array interface
* of feature columns * of feature columns
* @param missing missing value * @param missing missing value
* @param nthread threads number * @param nthread threads number
@ -194,36 +205,30 @@ public class DMatrix {
} }
/** /**
* Set label of DMatrix from cuda array interface * flatten a mat to array
*
* @param column the XGBoost Column to provide the cuda array interface
* of label column
* @throws XGBoostError native error
*/ */
public void setLabel(Column column) throws XGBoostError { private static float[] flatten(float[][] mat) {
setXGBDMatrixInfo("label", column.getArrayInterfaceJson()); int size = 0;
for (float[] array : mat) size += array.length;
float[] result = new float[size];
int pos = 0;
for (float[] ar : mat) {
System.arraycopy(ar, 0, result, pos, ar.length);
pos += ar.length;
}
return result;
} }
/** /**
* Set weight of DMatrix from cuda array interface * Set query id of DMatrix from array interface
* *
* @param column the XGBoost Column to provide the cuda array interface * @param column the XGBoost Column to provide the array interface
* of weight column * of query id column
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public void setWeight(Column column) throws XGBoostError { public void setQueryId(Column column) throws XGBoostError {
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson()); setXGBDMatrixInfo("qid", column.getArrayInterfaceJson());
}
/**
* Set base margin of DMatrix from cuda array interface
*
* @param column the XGBoost Column to provide the cuda array interface
* of base margin column
* @throws XGBoostError native error
*/
public void setBaseMargin(Column column) throws XGBoostError {
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
} }
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError { private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
@ -257,17 +262,9 @@ public class DMatrix {
return outValue[0]; return outValue[0];
} }
/**
* Set feature names
* @param values feature names to be set
* @throws XGBoostError
*/
public void setFeatureNames(String[] values) throws XGBoostError {
setXGBDMatrixFeatureInfo("feature_name", values);
}
/** /**
* Get feature names * Get feature names
*
* @return an array of feature names to be returned * @return an array of feature names to be returned
* @throws XGBoostError * @throws XGBoostError
*/ */
@ -276,16 +273,18 @@ public class DMatrix {
} }
/** /**
* Set feature types * Set feature names
* @param values feature types to be set *
* @param values feature names to be set
* @throws XGBoostError * @throws XGBoostError
*/ */
public void setFeatureTypes(String[] values) throws XGBoostError { public void setFeatureNames(String[] values) throws XGBoostError {
setXGBDMatrixFeatureInfo("feature_type", values); setXGBDMatrixFeatureInfo("feature_name", values);
} }
/** /**
* Get feature types * Get feature types
*
* @return an array of feature types to be returned * @return an array of feature types to be returned
* @throws XGBoostError * @throws XGBoostError
*/ */
@ -294,46 +293,23 @@ public class DMatrix {
} }
/** /**
* set label of dmatrix * Set feature types
* *
* @param labels labels * @param values feature types to be set
* @throws XGBoostError
*/
public void setFeatureTypes(String[] values) throws XGBoostError {
setXGBDMatrixFeatureInfo("feature_type", values);
}
/**
* Get group sizes of DMatrix
*
* @return group size as array
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public void setLabel(float[] labels) throws XGBoostError { public int[] getGroup() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); return getIntInfo("group_ptr");
}
/**
* set weight of each instance
*
* @param weights weights
* @throws XGBoostError native error
*/
public void setWeight(float[] weights) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
* Set base margin (initial prediction).
*
* The margin must have the same number of elements as the number of
* rows in this matrix.
*/
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
if (baseMargin.length != rowNum()) {
throw new IllegalArgumentException(String.format(
"base margin must have exactly %s elements, got %s",
rowNum(), baseMargin.length));
}
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
* Set base margin (initial prediction).
*/
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
setBaseMargin(flatten(baseMargin));
} }
/** /**
@ -347,13 +323,13 @@ public class DMatrix {
} }
/** /**
* Get group sizes of DMatrix * Set query ids (used for ranking)
* *
* @param qid the query ids
* @throws XGBoostError native error * @throws XGBoostError native error
* @return group size as array
*/ */
public int[] getGroup() throws XGBoostError { public void setQueryId(int[] qid) throws XGBoostError {
return getIntInfo("group_ptr"); XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "qid", qid));
} }
private float[] getFloatInfo(String field) throws XGBoostError { private float[] getFloatInfo(String field) throws XGBoostError {
@ -378,6 +354,27 @@ public class DMatrix {
return getFloatInfo("label"); return getFloatInfo("label");
} }
/**
 * Attach the label column to this DMatrix, handing it over as array-interface JSON.
 *
 * @param column the XGBoost Column whose array interface describes the label column
 * @throws XGBoostError native error
 */
public void setLabel(Column column) throws XGBoostError {
  final String labelJson = column.getArrayInterfaceJson();
  setXGBDMatrixInfo("label", labelJson);
}
/**
 * Set the labels of this DMatrix from a float array.
 *
 * @param labels the label values
 * @throws XGBoostError native error
 */
public void setLabel(float[] labels) throws XGBoostError {
  final int status = XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
  XGBoostJNI.checkCall(status);
}
/** /**
* get weight of the DMatrix * get weight of the DMatrix
* *
@ -388,6 +385,27 @@ public class DMatrix {
return getFloatInfo("weight"); return getFloatInfo("weight");
} }
/**
 * Attach the weight column to this DMatrix, handing it over as array-interface JSON.
 *
 * @param column the XGBoost Column whose array interface describes the weight column
 * @throws XGBoostError native error
 */
public void setWeight(Column column) throws XGBoostError {
  final String weightJson = column.getArrayInterfaceJson();
  setXGBDMatrixInfo("weight", weightJson);
}
/**
 * Set the per-instance weights of this DMatrix from a float array.
 *
 * @param weights the weight values
 * @throws XGBoostError native error
 */
public void setWeight(float[] weights) throws XGBoostError {
  final int status = XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
  XGBoostJNI.checkCall(status);
}
/** /**
* Get base margin of the DMatrix. * Get base margin of the DMatrix.
*/ */
@ -395,6 +413,40 @@ public class DMatrix {
return getFloatInfo("base_margin"); return getFloatInfo("base_margin");
} }
/**
 * Attach the base margin column to this DMatrix, handing it over as array-interface JSON.
 *
 * @param column the XGBoost Column whose array interface describes the base margin column
 * @throws XGBoostError native error
 */
public void setBaseMargin(Column column) throws XGBoostError {
  final String marginJson = column.getArrayInterfaceJson();
  setXGBDMatrixInfo("base_margin", marginJson);
}
/**
 * Set base margin (initial prediction).
 * <p>
 * The array length must equal the number of rows in this matrix.
 *
 * @param baseMargin the margin values
 * @throws IllegalArgumentException if the array length differs from {@code rowNum()}
 * @throws XGBoostError native error
 */
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
  // Happy path first: lengths agree, so push the values to the native side and stop.
  if (baseMargin.length == rowNum()) {
    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
    return;
  }
  throw new IllegalArgumentException(String.format(
      "base margin must have exactly %s elements, got %s",
      rowNum(), baseMargin.length));
}
/**
 * Set base margin (initial prediction) from a two-dimensional array.
 * <p>
 * The rows are flattened into one array, which is then passed to the
 * {@code float[]} overload.
 */
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
  final float[] flattened = flatten(baseMargin);
  setBaseMargin(flattened);
}
/** /**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
* *
@ -448,22 +500,6 @@ public class DMatrix {
return handle; return handle;
} }
/**
 * Concatenate the rows of a two-dimensional array, in order, into one flat array.
 */
private static float[] flatten(float[][] mat) {
  // First pass: total number of elements across all rows.
  int total = 0;
  for (int r = 0; r < mat.length; r++) {
    total += mat[r].length;
  }
  // Second pass: copy each row behind the previous one.
  float[] flat = new float[total];
  int offset = 0;
  for (int r = 0; r < mat.length; r++) {
    final float[] row = mat[r];
    System.arraycopy(row, 0, flat, offset, row.length);
    offset += row.length;
  }
  return flat;
}
@Override @Override
protected void finalize() { protected void finalize() {
dispose(); dispose();
@ -475,4 +511,12 @@ public class DMatrix {
handle = 0; handle = 0;
} }
} }
/**
 * Storage format of a sparse matrix: CSR (compressed sparse row) or
 * CSC (compressed sparse column).
 */
public enum SparseType {
CSR,
CSC
}
} }

View File

@ -54,7 +54,7 @@ class XGBoostJNI {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out); String cache_info, float missing, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam, float[] data, int shapeParam,

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014-2023 by Contributors Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._ import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.LabeledPoint import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix} import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError}
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/** /**
@ -37,10 +37,13 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
* *
* @param dataIter An iterator of LabeledPoint * @param dataIter An iterator of LabeledPoint
* @param cacheInfo Cache path information, used for external memory setting, null by default. * @param cacheInfo Cache path information, used for external memory setting, null by default.
* @param missing Which value will be treated as the missing value
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
def this(dataIter: Iterator[LabeledPoint], cacheInfo: String = null) { def this(dataIter: Iterator[LabeledPoint],
this(new JDMatrix(dataIter.asJava, cacheInfo)) cacheInfo: String = null,
missing: Float = Float.NaN) {
this(new JDMatrix(dataIter.asJava, cacheInfo, missing))
} }
/** /**
@ -93,6 +96,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/** /**
* Create the normal DMatrix from column array interface * Create the normal DMatrix from column array interface
*
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface * @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
* of feature columns * of feature columns
* @param missing missing value * @param missing missing value
@ -181,6 +185,16 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.setGroup(group) jDMatrix.setGroup(group)
} }
/**
 * Assign query ids to the rows of this DMatrix (used for ranking).
 *
 * @param qid query ids
 */
@throws(classOf[XGBoostError])
def setQueryId(qid: Array[Int]): Unit = jDMatrix.setQueryId(qid)
/** /**
* Set label of DMatrix from cuda array interface * Set label of DMatrix from cuda array interface
*/ */
@ -205,8 +219,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.setBaseMargin(column) jDMatrix.setBaseMargin(column)
} }
/**
 * Set the query ids of this DMatrix from a column array interface.
 *
 * @param column the column handed to the wrapped Java DMatrix
 */
@throws(classOf[XGBoostError])
def setQueryId(column: Column): Unit = jDMatrix.setQueryId(column)
/** /**
* set feature names * set feature names
*
* @param values feature names * @param values feature names
* @throws ml.dmlc.xgboost4j.java.XGBoostError * @throws ml.dmlc.xgboost4j.java.XGBoostError
*/ */
@ -217,6 +240,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/** /**
* set feature types * set feature types
*
* @param values feature types * @param values feature types
* @throws ml.dmlc.xgboost4j.java.XGBoostError * @throws ml.dmlc.xgboost4j.java.XGBoostError
*/ */
@ -265,6 +289,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/** /**
* get feature names * get feature names
*
* @throws ml.dmlc.xgboost4j.java.XGBoostError * @throws ml.dmlc.xgboost4j.java.XGBoostError
* @return * @return
*/ */
@ -275,6 +300,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/** /**
* get feature types * get feature types
*
* @throws ml.dmlc.xgboost4j.java.XGBoostError * @throws ml.dmlc.xgboost4j.java.XGBoostError
* @return * @return
*/ */

View File

@ -214,7 +214,7 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I * Signature: (Ljava/util/Iterator;Ljava/lang/String;F[J)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jfloat jmissing, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
std::unique_ptr<char const, Deleter<char const>> cache_info; std::unique_ptr<char const, Deleter<char const>> cache_info;
if (jcache_info != nullptr) { if (jcache_info != nullptr) {
@ -222,8 +222,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
jenv->ReleaseStringUTFChars(jcache_info, ptr); jenv->ReleaseStringUTFChars(jcache_info, ptr);
}}; }};
} }
auto missing = static_cast<float>(jmissing);
int ret = int ret =
XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), &result); XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(),
missing,&result);
JVM_CHECK_CALL(ret); JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
return ret; return ret;

View File

@ -26,10 +26,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromDataIter * Method: XGDMatrixCreateFromDataIter
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I * Signature: (Ljava/util/Iterator;Ljava/lang/String;F[J)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *, jclass, jobject, jstring, jlongArray); (JNIEnv *, jclass, jobject, jstring, jfloat, jlongArray);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI

View File

@ -15,15 +15,18 @@
*/ */
package ml.dmlc.xgboost4j.java; package ml.dmlc.xgboost4j.java;
import java.io.*; import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import junit.framework.TestCase; import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import ml.dmlc.xgboost4j.LabeledPoint; import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
@ -36,6 +39,32 @@ import static org.junit.Assert.assertEquals;
*/ */
public class DMatrixTest { public class DMatrixTest {
/**
 * Builds a DMatrix from the same iterator data with several "missing" sentinels and
 * checks how many entries are counted as present each time.
 * <p>
 * Uses JUnit assertions rather than the Java {@code assert} keyword: the keyword is a
 * no-op when the JVM runs without {@code -ea} and reports no expected/actual values.
 */
@Test
public void testCreateFromDataIteratorWithMissingValue() throws XGBoostError {
  // Three rows of four features = 12 entries: one NaN, four zeros, one 21.
  java.util.List<LabeledPoint> blist = new java.util.LinkedList<>();
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{1, 0, 0, 0}));
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{Float.NaN, 13, 14, 15}));
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{21, 23, 0, 25}));

  // Default missing value (Float.NaN): only the NaN entry is expected to be dropped.
  DMatrix dmat = new DMatrix(blist.iterator(), null);
  assertEquals(11L, dmat.nonMissingNum());

  // missing value 0: the four zeros are expected to be dropped, and the NaN as well.
  dmat = new DMatrix(blist.iterator(), null, 0.0f);
  assertEquals(12L - 4 - 1, dmat.nonMissingNum());

  // missing value 21: the single 21 is expected to be dropped, and the NaN as well.
  dmat = new DMatrix(blist.iterator(), null, 21.0f);
  assertEquals(12L - 1 - 1, dmat.nonMissingNum());

  // missing value 101010101010 matches no entry: only the NaN is expected to be dropped.
  dmat = new DMatrix(blist.iterator(), null, 101010101010.0f);
  assertEquals(12L - 1, dmat.nonMissingNum());
}
@Test @Test
public void testCreateFromDataIterator() throws XGBoostError { public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator //create DMatrix from DataIterator
@ -428,4 +457,40 @@ public class DMatrixTest {
String[] retFeatureTypes = dmat.getFeatureTypes(); String[] retFeatureTypes = dmat.getFeatureTypes();
assertArrayEquals(featureTypes, retFeatureTypes); assertArrayEquals(featureTypes, retFeatureTypes);
} }
/**
 * Sets query ids on a dense DMatrix and checks the array returned by
 * {@code getGroup()} against the expected group boundary offsets.
 */
@Test
public void testSetAndGetQueryId() throws XGBoostError {
  // create DMatrix from 10*5 dense matrix
  int nrow = 10;
  int ncol = 5;
  float[] data0 = new float[nrow * ncol];
  // fill features with random numbers; the values do not affect the grouping checks
  Random random = new Random();
  for (int i = 0; i < nrow * ncol; i++) {
    data0[i] = random.nextFloat();
  }

  // create label
  float[] label0 = new float[nrow];
  for (int i = 0; i < nrow; i++) {
    label0[i] = random.nextFloat();
  }

  DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
  dmat0.setLabel(label0);

  // every row has a distinct query id: ten groups of one row each
  int[] qid = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  int[] qidExpected = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  dmat0.setQueryId(qid);
  // assertArrayEquals reports the first differing index, unlike assertTrue(Arrays.equals(..))
  assertArrayEquals(qidExpected, dmat0.getGroup());

  // runs of equal query ids form six groups with sizes 3, 1, 2, 2, 1, 1
  int[] qid1 = new int[]{10, 10, 10, 20, 60, 60, 80, 80, 90, 100};
  int[] qidExpected1 = new int[]{0, 3, 4, 6, 8, 9, 10};
  dmat0.setQueryId(qid1);
  assertArrayEquals(qidExpected1, dmat0.getGroup());
}
} }

View File

@ -253,7 +253,9 @@ XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) {
XGB_DLL int XGDMatrixCreateFromDataIter( XGB_DLL int XGDMatrixCreateFromDataIter(
void *data_handle, // a Java iterator void *data_handle, // a Java iterator
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
const char *cache_info, DMatrixHandle *out) { const char *cache_info,
float missing,
DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
std::string scache; std::string scache;
@ -264,10 +266,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
data_handle, callback); data_handle, callback);
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix> { *out = new std::shared_ptr<DMatrix> {
DMatrix::Create( DMatrix::Create(&adapter, missing, 1, scache)
&adapter, std::numeric_limits<float>::quiet_NaN(),
1, scache
)
}; };
API_END(); API_END();
} }