[breakinig] [jvm-packages] change DeviceQuantileDmatrix into QuantileDMatrix (#8461)
This commit is contained in:
@@ -20,6 +20,13 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
char const *config, jlongArray jout) {
|
||||
API_BEGIN();
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
@@ -379,7 +379,7 @@ int Next(DataIterHandle self) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jiter,
|
||||
jfloat jmissing,
|
||||
jint jmax_bin, jint jnthread,
|
||||
@@ -392,5 +392,21 @@ XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
char const *config, jlongArray jout) {
|
||||
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
|
||||
DMatrixHandle result;
|
||||
|
||||
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
|
||||
if (jref_iter) {
|
||||
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
|
||||
}
|
||||
auto ret = XGQuantileDMatrixCreateFromCallback(
|
||||
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -34,7 +34,7 @@ import ai.rapids.cudf.CSVOptions;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
@@ -107,7 +107,7 @@ public class BoosterTest {
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
tables.add(batch);
|
||||
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
|
||||
DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches1 = new HashMap<>();
|
||||
watches1.put("train", incrementalDMatrix);
|
||||
|
||||
@@ -29,7 +29,7 @@ import org.junit.Test;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
@@ -117,7 +117,7 @@ public class DMatrixTest {
|
||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
|
||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
|
||||
|
||||
DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
||||
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
||||
|
||||
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
|
||||
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
|
||||
|
||||
@@ -22,9 +22,9 @@ import ai.rapids.cudf.Table
|
||||
import org.scalatest.FunSuite
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
|
||||
class DeviceQuantileDMatrixSuite extends FunSuite {
|
||||
class QuantileDMatrixSuite extends FunSuite {
|
||||
|
||||
test("DeviceQuantileDMatrix test") {
|
||||
test("QuantileDMatrix test") {
|
||||
|
||||
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
||||
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
||||
@@ -51,8 +51,7 @@ class DeviceQuantileDMatrixSuite extends FunSuite {
|
||||
val batches = new ArrayBuffer[CudfColumnBatch]()
|
||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
|
||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
|
||||
val dmatrix = new DeviceQuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
|
||||
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
||||
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
||||
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
||||
Reference in New Issue
Block a user