[breakinig] [jvm-packages] change DeviceQuantileDmatrix into QuantileDMatrix (#8461)

This commit is contained in:
Bobby Wang
2022-12-05 12:23:21 +08:00
committed by GitHub
parent 78d65a1928
commit f1e9bbcee5
13 changed files with 150 additions and 94 deletions

View File

@@ -20,6 +20,13 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
common::AssertGPUSupport();
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
char const *config, jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();
API_END();
}
} // namespace jni
} // namespace xgboost
#endif // XGBOOST_USE_CUDA

View File

@@ -379,7 +379,7 @@ int Next(DataIterHandle self) {
}
} // anonymous namespace
XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
@@ -392,5 +392,21 @@ XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass
setHandle(jenv, jout, result);
return ret;
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
char const *config, jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
DMatrixHandle result;
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
if (jref_iter) {
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
}
auto ret = XGQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
setHandle(jenv, jout, result);
return ret;
}
} // namespace jni
} // namespace xgboost

View File

@@ -34,7 +34,7 @@ import ai.rapids.cudf.CSVOptions;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
@@ -107,7 +107,7 @@ public class BoosterTest {
List<ColumnBatch> tables = new LinkedList<>();
tables.add(batch);
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
//set watchList
HashMap<String, DMatrix> watches1 = new HashMap<>();
watches1.put("train", incrementalDMatrix);

View File

@@ -29,7 +29,7 @@ import org.junit.Test;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
@@ -117,7 +117,7 @@ public class DMatrixTest {
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));

View File

@@ -22,9 +22,9 @@ import ai.rapids.cudf.Table
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
class DeviceQuantileDMatrixSuite extends FunSuite {
class QuantileDMatrixSuite extends FunSuite {
test("DeviceQuantileDMatrix test") {
test("QuantileDMatrix test") {
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
@@ -51,8 +51,7 @@ class DeviceQuantileDMatrixSuite extends FunSuite {
val batches = new ArrayBuffer[CudfColumnBatch]()
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
val dmatrix = new DeviceQuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))