[breakinig] [jvm-packages] change DeviceQuantileDmatrix into QuantileDMatrix (#8461)
This commit is contained in:
@@ -34,7 +34,7 @@ import ai.rapids.cudf.CSVOptions;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
@@ -107,7 +107,7 @@ public class BoosterTest {
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
tables.add(batch);
|
||||
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
|
||||
DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches1 = new HashMap<>();
|
||||
watches1.put("train", incrementalDMatrix);
|
||||
|
||||
@@ -29,7 +29,7 @@ import org.junit.Test;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
@@ -117,7 +117,7 @@ public class DMatrixTest {
|
||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
|
||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
|
||||
|
||||
DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
||||
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
||||
|
||||
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
|
||||
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
|
||||
|
||||
@@ -22,9 +22,9 @@ import ai.rapids.cudf.Table
|
||||
import org.scalatest.FunSuite
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
|
||||
class DeviceQuantileDMatrixSuite extends FunSuite {
|
||||
class QuantileDMatrixSuite extends FunSuite {
|
||||
|
||||
test("DeviceQuantileDMatrix test") {
|
||||
test("QuantileDMatrix test") {
|
||||
|
||||
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
||||
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
||||
@@ -51,8 +51,7 @@ class DeviceQuantileDMatrixSuite extends FunSuite {
|
||||
val batches = new ArrayBuffer[CudfColumnBatch]()
|
||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
|
||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
|
||||
val dmatrix = new DeviceQuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
|
||||
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
||||
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
||||
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
||||
Reference in New Issue
Block a user