[jvm-packages] remove xgboost4j-gpu and rework cudf column (#10630)
This commit is contained in:
@@ -24,7 +24,7 @@
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-gpu_2.12</artifactId>
|
||||
<artifactId>xgboost4j_2.12</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
@@ -51,5 +51,17 @@
|
||||
<version>${spark.rapids.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${fasterxml.jackson.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.DType;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* CudfColumn is the CUDF column representing, providing the cuda array interface
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class CudfColumn extends Column {
|
||||
private List<Long> shape = new ArrayList<>(); // row count
|
||||
private List<Object> data = new ArrayList<>(); // gpu data buffer address
|
||||
private String typestr;
|
||||
private int version = 1;
|
||||
private CudfColumn mask = null;
|
||||
|
||||
public CudfColumn(long shape, long data, String typestr, int version) {
|
||||
this.shape.add(shape);
|
||||
this.data.add(data);
|
||||
this.data.add(false);
|
||||
this.typestr = typestr;
|
||||
this.version = version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create CudfColumn according to ColumnVector
|
||||
*/
|
||||
public static CudfColumn from(ColumnVector cv) {
|
||||
BaseDeviceMemoryBuffer dataBuffer = cv.getData();
|
||||
assert dataBuffer != null;
|
||||
|
||||
DType dType = cv.getType();
|
||||
String typeStr = "";
|
||||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
||||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
||||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
||||
dType == DType.TIMESTAMP_SECONDS) {
|
||||
typeStr = "<f" + dType.getSizeInBytes();
|
||||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
||||
dType == DType.INT32 || dType == DType.INT64) {
|
||||
typeStr = "<i" + dType.getSizeInBytes();
|
||||
} else {
|
||||
// Unsupported type.
|
||||
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
||||
}
|
||||
|
||||
CudfColumn data = new CudfColumn(cv.getRowCount(), dataBuffer.getAddress(), typeStr, 1);
|
||||
|
||||
BaseDeviceMemoryBuffer validBuffer = cv.getValid();
|
||||
if (validBuffer != null && cv.getNullCount() != 0) {
|
||||
CudfColumn mask = new CudfColumn(cv.getRowCount(), validBuffer.getAddress(), "<t1", 1);
|
||||
data.setMask(mask);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
public List<Long> getShape() {
|
||||
return shape;
|
||||
}
|
||||
|
||||
public List<Object> getData() {
|
||||
return data;
|
||||
}
|
||||
|
||||
public String getTypestr() {
|
||||
return typestr;
|
||||
}
|
||||
|
||||
public int getVersion() {
|
||||
return version;
|
||||
}
|
||||
|
||||
public CudfColumn getMask() {
|
||||
return mask;
|
||||
}
|
||||
|
||||
public void setMask(CudfColumn mask) {
|
||||
this.mask = mask;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toJson() {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||
try {
|
||||
List<CudfColumn> objects = new ArrayList<>(1);
|
||||
objects.add(this);
|
||||
return mapper.writeValueAsString(objects);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* CudfColumnBatch wraps multiple CudfColumns to provide the cuda
|
||||
* array interface json string for all columns.
|
||||
*/
|
||||
public class CudfColumnBatch extends ColumnBatch {
|
||||
@JsonIgnore
|
||||
private final Table featureTable;
|
||||
@JsonIgnore
|
||||
private final Table labelTable;
|
||||
@JsonIgnore
|
||||
private final Table weightTable;
|
||||
@JsonIgnore
|
||||
private final Table baseMarginTable;
|
||||
@JsonIgnore
|
||||
private final Table qidTable;
|
||||
|
||||
private List<CudfColumn> features;
|
||||
private List<CudfColumn> label;
|
||||
private List<CudfColumn> weight;
|
||||
private List<CudfColumn> baseMargin;
|
||||
private List<CudfColumn> qid;
|
||||
|
||||
public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
|
||||
Table baseMarginTable, Table qidTable) {
|
||||
this.featureTable = featureTable;
|
||||
this.labelTable = labelTable;
|
||||
this.weightTable = weightTable;
|
||||
this.baseMarginTable = baseMarginTable;
|
||||
this.qidTable = qidTable;
|
||||
|
||||
features = initializeCudfColumns(featureTable);
|
||||
if (labelTable != null) {
|
||||
assert labelTable.getNumberOfColumns() == 1;
|
||||
label = initializeCudfColumns(labelTable);
|
||||
}
|
||||
|
||||
if (weightTable != null) {
|
||||
assert weightTable.getNumberOfColumns() == 1;
|
||||
weight = initializeCudfColumns(weightTable);
|
||||
}
|
||||
|
||||
if (baseMarginTable != null) {
|
||||
baseMargin = initializeCudfColumns(baseMarginTable);
|
||||
}
|
||||
|
||||
if (qidTable != null) {
|
||||
qid = initializeCudfColumns(qidTable);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private List<CudfColumn> initializeCudfColumns(Table table) {
|
||||
assert table != null && table.getNumberOfColumns() > 0;
|
||||
|
||||
return IntStream.range(0, table.getNumberOfColumns())
|
||||
.mapToObj(table::getColumn)
|
||||
.map(CudfColumn::from)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<CudfColumn> getFeatures() {
|
||||
return features;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getLabel() {
|
||||
return label;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getWeight() {
|
||||
return weight;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getBaseMargin() {
|
||||
return baseMargin;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getQid() {
|
||||
return qid;
|
||||
}
|
||||
|
||||
public String toJson() {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||
try {
|
||||
return mapper.writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toFeaturesJson() {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
return mapper.writeValueAsString(features);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (featureTable != null) featureTable.close();
|
||||
if (labelTable != null) labelTable.close();
|
||||
if (weightTable != null) weightTable.close();
|
||||
if (baseMarginTable != null) baseMarginTable.close();
|
||||
if (qidTable != null) qidTable.close();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* QuantileDMatrix will only be used to train
|
||||
*/
|
||||
public class QuantileDMatrix extends DMatrix {
|
||||
/**
|
||||
* Create QuantileDMatrix from iterator based on the cuda array interface
|
||||
*
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public QuantileDMatrix(
|
||||
Iterator<ColumnBatch> iter,
|
||||
float missing,
|
||||
int maxBin,
|
||||
int nthread) throws XGBoostError {
|
||||
super(0);
|
||||
long[] out = new long[1];
|
||||
String conf = getConfig(missing, maxBin, nthread);
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
|
||||
iter, null, conf, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
|
||||
}
|
||||
|
||||
private String getConfig(float missing, int maxBin, int nthread) {
|
||||
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
|
||||
missing, maxBin, nthread);
|
||||
}
|
||||
}
|
||||
@@ -17,14 +17,12 @@
|
||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
||||
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.rdd.RDD
|
||||
@@ -325,7 +323,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
throw new RuntimeException("Something wrong for feature indices")
|
||||
}
|
||||
try {
|
||||
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null)
|
||||
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null, null)
|
||||
val dm = new DMatrix(cudfColumnBatch, missing, 1)
|
||||
if (dm == null) {
|
||||
Iterator.empty
|
||||
@@ -586,7 +584,8 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava));
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava),
|
||||
null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import ai.rapids.cudf.*;
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Tests the BoosterTest trained by DMatrix
|
||||
*
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public class BoosterTest {
|
||||
|
||||
@Test
|
||||
public void testBooster() throws XGBoostError {
|
||||
String trainingDataPath = getClass().getClassLoader()
|
||||
.getResource("veterans_lung_cancer.csv").getPath();
|
||||
Schema schema = Schema.builder()
|
||||
.column(DType.FLOAT32, "A")
|
||||
.column(DType.FLOAT32, "B")
|
||||
.column(DType.FLOAT32, "C")
|
||||
.column(DType.FLOAT32, "D")
|
||||
|
||||
.column(DType.FLOAT32, "E")
|
||||
.column(DType.FLOAT32, "F")
|
||||
.column(DType.FLOAT32, "G")
|
||||
.column(DType.FLOAT32, "H")
|
||||
|
||||
.column(DType.FLOAT32, "I")
|
||||
.column(DType.FLOAT32, "J")
|
||||
.column(DType.FLOAT32, "K")
|
||||
.column(DType.FLOAT32, "L")
|
||||
|
||||
.column(DType.FLOAT32, "label")
|
||||
.build();
|
||||
CSVOptions opts = CSVOptions.builder()
|
||||
.hasHeader().build();
|
||||
|
||||
int maxBin = 16;
|
||||
int round = 10;
|
||||
//set params
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 2);
|
||||
put("objective", "binary:logistic");
|
||||
put("num_round", round);
|
||||
put("num_workers", 1);
|
||||
put("tree_method", "hist");
|
||||
put("device", "cuda");
|
||||
put("max_bin", maxBin);
|
||||
}
|
||||
};
|
||||
|
||||
try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
|
||||
ColumnVector[] df = new ColumnVector[10];
|
||||
// exclude the first two columns, they are label bounds and contain inf.
|
||||
for (int i = 2; i < 12; ++i) {
|
||||
df[i - 2] = tmpTable.getColumn(i);
|
||||
}
|
||||
try (Table X = new Table(df);) {
|
||||
ColumnVector[] labels = new ColumnVector[1];
|
||||
labels[0] = tmpTable.getColumn(12);
|
||||
|
||||
try (Table y = new Table(labels);) {
|
||||
|
||||
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null, null);
|
||||
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
|
||||
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<>();
|
||||
|
||||
DMatrix dMatrix1 = new DMatrix(batch, Float.NaN, 1);
|
||||
dMatrix1.setLabel(labelColumn);
|
||||
watches.put("train", dMatrix1);
|
||||
Booster model1 = XGBoost.train(dMatrix1, paramMap, round, watches, null, null);
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
tables.add(batch);
|
||||
DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches1 = new HashMap<>();
|
||||
watches1.put("train", incrementalDMatrix);
|
||||
Booster model2 = XGBoost.train(incrementalDMatrix, paramMap, round, watches1, null, null);
|
||||
|
||||
float[][] predicat1 = model1.predict(dMatrix1);
|
||||
float[][] predicat2 = model2.predict(dMatrix1);
|
||||
|
||||
for (int i = 0; i < tmpTable.getRowCount(); i++) {
|
||||
TestCase.assertTrue(predicat1[i][0] - predicat2[i][0] < 1e-6);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
/**
|
||||
* Test suite for DMatrix based on GPU
|
||||
*/
|
||||
public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testCreateFromArrayInterfaceColumns() {
|
||||
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
|
||||
Integer[] groups = new Integer[]{1, 1, 7, 7, 19, 26};
|
||||
int[] expectedGroup = new int[]{0, 2, 4, 5, 6};
|
||||
|
||||
Throwable ex = null;
|
||||
try (
|
||||
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
|
||||
Table y = new Table.TestBuilder().column(labelFloats).build();
|
||||
Table w = new Table.TestBuilder().column(labelFloats).build();
|
||||
Table q = new Table.TestBuilder().column(groups).build();
|
||||
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
|
||||
|
||||
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null, null);
|
||||
|
||||
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
|
||||
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
|
||||
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
|
||||
CudfColumn qidColumn = CudfColumn.from(q.getColumn(0));
|
||||
|
||||
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
|
||||
dMatrix.setLabel(labelColumn);
|
||||
dMatrix.setWeight(weightColumn);
|
||||
dMatrix.setBaseMargin(baseMarginColumn);
|
||||
dMatrix.setQueryId(qidColumn);
|
||||
|
||||
String[] featureNames = new String[]{"f1"};
|
||||
dMatrix.setFeatureNames(featureNames);
|
||||
String[] retFeatureNames = dMatrix.getFeatureNames();
|
||||
assertArrayEquals(featureNames, retFeatureNames);
|
||||
|
||||
String[] featureTypes = new String[]{"i"};
|
||||
dMatrix.setFeatureTypes(featureTypes);
|
||||
String[] retFeatureTypes = dMatrix.getFeatureTypes();
|
||||
assertArrayEquals(featureTypes, retFeatureTypes);
|
||||
|
||||
float[] anchor = convertFloatTofloat(labelFloats);
|
||||
float[] label = dMatrix.getLabel();
|
||||
float[] weight = dMatrix.getWeight();
|
||||
float[] baseMargin = dMatrix.getBaseMargin();
|
||||
int[] group = dMatrix.getGroup();
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(anchor, label));
|
||||
TestCase.assertTrue(Arrays.equals(anchor, weight));
|
||||
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
|
||||
TestCase.assertTrue(Arrays.equals(expectedGroup, group));
|
||||
} catch (Throwable e) {
|
||||
ex = e;
|
||||
e.printStackTrace();
|
||||
}
|
||||
TestCase.assertNull(ex);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromColumnDataIterator() throws XGBoostError {
|
||||
|
||||
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
|
||||
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
|
||||
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
|
||||
Integer[] groups1 = new Integer[]{1, 1, 7, 7, 19, 26};
|
||||
|
||||
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
|
||||
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
|
||||
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
|
||||
Integer[] groups2 = new Integer[]{30, 30, 30, 40, 40};
|
||||
|
||||
int[] expectedGroup = new int[]{0, 2, 4, 5, 6, 9, 11};
|
||||
|
||||
try (
|
||||
Table X_0 = new Table.TestBuilder()
|
||||
.column(1.2f, null, 5.2f, 7.2f, 9.2f)
|
||||
.column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f)
|
||||
.build();
|
||||
Table y_0 = new Table.TestBuilder().column(label1).build();
|
||||
Table w_0 = new Table.TestBuilder().column(weight1).build();
|
||||
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
|
||||
Table q_0 = new Table.TestBuilder().column(groups1).build();
|
||||
|
||||
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
|
||||
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
|
||||
Table y_1 = new Table.TestBuilder().column(label2).build();
|
||||
Table w_1 = new Table.TestBuilder().column(weight2).build();
|
||||
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
|
||||
Table q_1 = new Table.TestBuilder().column(groups2).build();
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
|
||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0));
|
||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1, q_1));
|
||||
|
||||
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
|
||||
|
||||
float[] anchorLabel = convertFloatTofloat(label1, label2);
|
||||
float[] anchorWeight = convertFloatTofloat(weight1, weight2);
|
||||
float[] anchorBaseMargin = convertFloatTofloat(baseMargin1, baseMargin2);
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
|
||||
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
|
||||
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
|
||||
TestCase.assertTrue(Arrays.equals(expectedGroup, dmat.getGroup()));
|
||||
}
|
||||
}
|
||||
|
||||
private float[] convertFloatTofloat(Float[]... datas) {
|
||||
int totalLength = 0;
|
||||
for (Float[] data : datas) {
|
||||
totalLength += data.length;
|
||||
}
|
||||
float[] floatArray = new float[totalLength];
|
||||
int index = 0;
|
||||
for (Float[] data : datas) {
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
floatArray[i + index] = data[i];
|
||||
}
|
||||
index += data.length;
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ai.rapids.cudf.Table
|
||||
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
class QuantileDMatrixSuite extends AnyFunSuite {
|
||||
|
||||
test("QuantileDMatrix test") {
|
||||
|
||||
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
||||
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
||||
val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)
|
||||
val group1 = Array[java.lang.Integer](1, 1, 7, 7, 19, 26)
|
||||
|
||||
val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
|
||||
val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
|
||||
val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)
|
||||
val group2 = Array[java.lang.Integer](30, 30, 30, 40, 40)
|
||||
|
||||
val expectedGroup = Array(0, 2, 4, 5, 6, 9, 11)
|
||||
|
||||
withResource(new Table.TestBuilder()
|
||||
.column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
|
||||
.column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f.asInstanceOf[java.lang.Float])
|
||||
.build) { X_0 =>
|
||||
withResource(new Table.TestBuilder().column(label1: _*).build) { y_0 =>
|
||||
withResource(new Table.TestBuilder().column(weight1: _*).build) { w_0 =>
|
||||
withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 =>
|
||||
withResource(new Table.TestBuilder().column(group1: _*).build) { q_0 =>
|
||||
withResource(new Table.TestBuilder()
|
||||
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
|
||||
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build) {
|
||||
X_1 =>
|
||||
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
|
||||
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
|
||||
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
|
||||
withResource(new Table.TestBuilder().column(group2: _*).build) { q_2 =>
|
||||
val batches = new ArrayBuffer[CudfColumnBatch]()
|
||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0)
|
||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1, q_2)
|
||||
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
||||
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
||||
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
||||
assert(dmatrix.getGroup().sameElements(expectedGroup))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Executes the provided code block and then closes the resource */
|
||||
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
||||
try {
|
||||
block(r)
|
||||
} finally {
|
||||
r.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user