[jvm-packages] remove xgboost4j-gpu and rework cudf column (#10630)

This commit is contained in:
Bobby Wang 2024-07-25 15:31:16 +08:00 committed by GitHub
parent fcae6301ec
commit d5834b68c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 509 additions and 632 deletions

View File

@ -33,7 +33,6 @@ def main(args):
for artifact in [
"xgboost-jvm",
"xgboost4j",
"xgboost4j-gpu",
"xgboost4j-spark",
"xgboost4j-spark-gpu",
"xgboost4j-flink",

View File

@ -2,11 +2,11 @@ find_package(JNI REQUIRED)
list(APPEND JVM_SOURCES
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp)
if(USE_CUDA)
list(APPEND JVM_SOURCES
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu)
endif()
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})

View File

@ -131,14 +131,6 @@ def native_build(args):
run("cmake .. " + " ".join(args))
run("cmake --build . --config Release" + maybe_parallel_build)
with cd("demo/CLI/regression"):
run(f'"{sys.executable}" mapfeat.py')
run(f'"{sys.executable}" mknfold.py machine.txt 1')
xgboost4j = "xgboost4j-gpu" if cli_args.use_cuda == "ON" else "xgboost4j"
xgboost4j_spark = (
"xgboost4j-spark-gpu" if cli_args.use_cuda == "ON" else "xgboost4j-spark"
)
print("copying native library")
library_name, os_folder = {
@ -155,26 +147,34 @@ def native_build(args):
"arm64": "aarch64", # on macOS & Windows ARM 64-bit
"aarch64": "aarch64",
}[platform.machine().lower()]
output_folder = "{}/src/main/resources/lib/{}/{}".format(
xgboost4j, os_folder, arch_folder
output_folder = "xgboost4j/src/main/resources/lib/{}/{}".format(
os_folder, arch_folder
)
maybe_makedirs(output_folder)
cp("../lib/" + library_name, output_folder)
print("copying train/test files")
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
# for xgboost4j
maybe_makedirs("xgboost4j/src/test/resources")
for file in glob.glob("../demo/data/agaricus.*"):
cp(file, "xgboost4j/src/test/resources")
# for xgboost4j-spark
maybe_makedirs("xgboost4j-spark/src/test/resources")
with cd("../demo/CLI/regression"):
run(f'"{sys.executable}" mapfeat.py')
run(f'"{sys.executable}" mknfold.py machine.txt 1')
for file in glob.glob("../demo/CLI/regression/machine.txt.t*"):
cp(file, "{}/src/test/resources".format(xgboost4j_spark))
cp(file, "xgboost4j-spark/src/test/resources")
for file in glob.glob("../demo/data/agaricus.*"):
cp(file, "{}/src/test/resources".format(xgboost4j_spark))
cp(file, "xgboost4j-spark/src/test/resources")
maybe_makedirs("{}/src/test/resources".format(xgboost4j))
for file in glob.glob("../demo/data/agaricus.*"):
cp(file, "{}/src/test/resources".format(xgboost4j))
# for xgboost4j-spark-gpu
if cli_args.use_cuda == "ON":
maybe_makedirs("xgboost4j-spark-gpu/src/test/resources")
for file in glob.glob("../demo/data/veterans_lung_cancer.csv"):
cp(file, "xgboost4j-spark-gpu/src/test/resources")
if __name__ == "__main__":

View File

@ -37,7 +37,7 @@
<junit.version>4.13.2</junit.version>
<spark.version>3.5.1</spark.version>
<spark.version.gpu>3.5.1</spark.version.gpu>
<fasterxml.jackson.version>2.17.2</fasterxml.jackson.version>
<fasterxml.jackson.version>2.15.0</fasterxml.jackson.version>
<scala.version>2.12.18</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.4.0</hadoop.version>
@ -105,7 +105,7 @@
<use.cuda>ON</use.cuda>
</properties>
<modules>
<module>xgboost4j-gpu</module>
<module>xgboost4j</module>
<module>xgboost4j-spark-gpu</module>
</modules>
</profile>
@ -117,7 +117,6 @@
<module>xgboost4j-example</module>
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
<module>xgboost4j-gpu</module>
<module>xgboost4j-spark-gpu</module>
</modules>
<build>
@ -243,7 +242,6 @@
<module>xgboost4j-example</module>
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
<module>xgboost4j-gpu</module>
<module>xgboost4j-spark-gpu</module>
</modules>
<build>

View File

@ -1,140 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>2.2.0-SNAPSHOT</version>
</parent>
<artifactId>xgboost4j-gpu_2.12</artifactId>
<name>xgboost4j-gpu</name>
<version>2.2.0-SNAPSHOT</version>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
<version>${scala-collection-compat.version}</version>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<classifier>${cudf.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>${scalatest.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.14.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.7.0</version>
<configuration>
<show>protected</show>
<nohelp>true</nohelp>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<skipAssembly>false</skipAssembly>
</configuration>
</plugin>
<plugin>
<artifactId>exec-maven-plugin</artifactId>
<groupId>org.codehaus.mojo</groupId>
<version>3.3.0</version>
<executions>
<execution>
<id>native</id>
<phase>generate-sources</phase>
<goals>
<goal>exec</goal>
</goals>
<configuration>
<executable>python</executable>
<arguments>
<argument>create_jni.py</argument>
<argument>--log-capi-invocation</argument>
<argument>${log.capi.invocation}</argument>
<argument>--use-cuda</argument>
<argument>${use.cuda}</argument>
</arguments>
<workingDirectory>${user.dir}</workingDirectory>
<skip>${skip.native.build}</skip>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.4.1</version>
<executions>
<execution>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<version>3.3.1</version>
<configuration>
<nonFilteredFileExtensions>
<nonFilteredFileExtension>dll</nonFilteredFileExtension>
<nonFilteredFileExtension>dylib</nonFilteredFileExtension>
<nonFilteredFileExtension>so</nonFilteredFileExtension>
</nonFilteredFileExtensions>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -1,110 +0,0 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ml.dmlc.xgboost4j.java.Column;
/**
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
* It will be used to generate the cuda array interface.
*/
public class CudfColumn extends Column {
private final long dataPtr; // gpu data buffer address
private final long shape; // row count
private final long validPtr; // gpu valid buffer address
private final int typeSize; // type size in bytes
private final String typeStr; // follow array interface spec
private final long nullCount; // null count
private String arrayInterface = null; // the cuda array interface
public static CudfColumn from(ColumnVector cv) {
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
long validPtr = 0;
if (validBuffer != null) {
validPtr = validBuffer.getAddress();
}
DType dType = cv.getType();
String typeStr = "";
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = "<f" + dType.getSizeInBytes();
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
dType == DType.INT32 || dType == DType.INT64) {
typeStr = "<i" + dType.getSizeInBytes();
} else {
// Unsupported type.
throw new IllegalArgumentException("Unsupported data type: " + dType);
}
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
dType.getSizeInBytes(), typeStr, cv.getNullCount());
}
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
long nullCount) {
this.dataPtr = dataPtr;
this.shape = shape;
this.validPtr = validPtr;
this.typeSize = typeSize;
this.typeStr = typeStr;
this.nullCount = nullCount;
}
@Override
public String getArrayInterfaceJson() {
// There is no race-condition
if (arrayInterface == null) {
arrayInterface = CudfUtils.buildArrayInterface(this);
}
return arrayInterface;
}
public long getDataPtr() {
return dataPtr;
}
public long getShape() {
return shape;
}
public long getValidPtr() {
return validPtr;
}
public int getTypeSize() {
return typeSize;
}
public String getTypeStr() {
return typeStr;
}
public long getNullCount() {
return nullCount;
}
}

View File

@ -1,88 +0,0 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.util.stream.IntStream;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.ColumnBatch;
/**
* Class to wrap CUDF Table to generate the cuda array interface.
*/
public class CudfColumnBatch extends ColumnBatch {
private final Table feature;
private final Table label;
private final Table weight;
private final Table baseMargin;
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
this.feature = feature;
this.label = labels;
this.weight = weights;
this.baseMargin = baseMargins;
}
@Override
public String getFeatureArrayInterface() {
return getArrayInterface(this.feature);
}
@Override
public String getLabelsArrayInterface() {
return getArrayInterface(this.label);
}
@Override
public String getWeightsArrayInterface() {
return getArrayInterface(this.weight);
}
@Override
public String getBaseMarginsArrayInterface() {
return getArrayInterface(this.baseMargin);
}
@Override
public void close() {
if (feature != null) feature.close();
if (label != null) label.close();
if (weight != null) weight.close();
if (baseMargin != null) baseMargin.close();
}
private String getArrayInterface(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
return "";
}
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
}
private CudfColumn[] getAsCudfColumn(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
// This will never happen.
return new CudfColumn[]{};
}
return IntStream.range(0, table.getNumberOfColumns())
.mapToObj((i) -> table.getColumn(i))
.map(CudfColumn::from)
.toArray(CudfColumn[]::new);
}
}

View File

@ -1,98 +0,0 @@
/*
Copyright (c) 2021-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.util.ArrayList;
/**
* Cudf utilities to build cuda array interface against {@link CudfColumn}
*/
class CudfUtils {
/**
* Build the cuda array interface based on CudfColumn(s)
* @param cudfColumns the CudfColumn(s) to be built
* @return the json format of cuda array interface
*/
public static String buildArrayInterface(CudfColumn... cudfColumns) {
return new Builder().add(cudfColumns).build();
}
// Helper class to build array interface string
private static class Builder {
private ArrayList<String> colArrayInterfaces = new ArrayList<String>();
private Builder add(CudfColumn... columns) {
if (columns == null || columns.length <= 0) {
throw new IllegalArgumentException("At least one ColumnData is required.");
}
for (CudfColumn cd : columns) {
colArrayInterfaces.add(buildColumnObject(cd));
}
return this;
}
private String build() {
StringBuilder builder = new StringBuilder();
builder.append("[");
for (int i = 0; i < colArrayInterfaces.size(); i++) {
builder.append(colArrayInterfaces.get(i));
if (i != colArrayInterfaces.size() - 1) {
builder.append(",");
}
}
builder.append("]");
return builder.toString();
}
/** build the whole column information including data and valid info */
private String buildColumnObject(CudfColumn column) {
if (column.getDataPtr() == 0) {
throw new IllegalArgumentException("Empty column data is NOT accepted!");
}
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
throw new IllegalArgumentException("Empty type string is NOT accepted!");
}
StringBuilder builder = new StringBuilder();
String colData = buildMetaObject(column.getDataPtr(), column.getShape(),
column.getTypeStr());
builder.append("{");
builder.append(colData);
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
String validString = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
builder.append(",\"mask\":");
builder.append("{");
builder.append(validString);
builder.append("}");
}
builder.append("}");
return builder.toString();
}
/** build the base information of a column */
private String buildMetaObject(long ptr, long shape, final String typeStr) {
StringBuilder builder = new StringBuilder();
builder.append("\"shape\":[" + shape + "],");
builder.append("\"data\":[" + ptr + "," + "false" + "],");
builder.append("\"typestr\":\"" + typeStr + "\",");
builder.append("\"version\":" + 1);
return builder.toString();
}
}
}

View File

@ -1 +0,0 @@
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java

View File

@ -1 +0,0 @@
../../../../xgboost4j/src/main/resources/xgboost4j-version.properties

View File

@ -1 +0,0 @@
../../../xgboost4j/src/main/scala/

View File

@ -24,7 +24,7 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-gpu_2.12</artifactId>
<artifactId>xgboost4j_2.12</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@ -51,5 +51,17 @@
<version>${spark.rapids.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,117 @@
/*
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.util.ArrayList;
import java.util.List;
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* CudfColumn is the CUDF column representing, providing the cuda array interface
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class CudfColumn extends Column {
private List<Long> shape = new ArrayList<>(); // row count
private List<Object> data = new ArrayList<>(); // gpu data buffer address
private String typestr;
private int version = 1;
private CudfColumn mask = null;
public CudfColumn(long shape, long data, String typestr, int version) {
this.shape.add(shape);
this.data.add(data);
this.data.add(false);
this.typestr = typestr;
this.version = version;
}
/**
* Create CudfColumn according to ColumnVector
*/
public static CudfColumn from(ColumnVector cv) {
BaseDeviceMemoryBuffer dataBuffer = cv.getData();
assert dataBuffer != null;
DType dType = cv.getType();
String typeStr = "";
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = "<f" + dType.getSizeInBytes();
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
dType == DType.INT32 || dType == DType.INT64) {
typeStr = "<i" + dType.getSizeInBytes();
} else {
// Unsupported type.
throw new IllegalArgumentException("Unsupported data type: " + dType);
}
CudfColumn data = new CudfColumn(cv.getRowCount(), dataBuffer.getAddress(), typeStr, 1);
BaseDeviceMemoryBuffer validBuffer = cv.getValid();
if (validBuffer != null && cv.getNullCount() != 0) {
CudfColumn mask = new CudfColumn(cv.getRowCount(), validBuffer.getAddress(), "<t1", 1);
data.setMask(mask);
}
return data;
}
public List<Long> getShape() {
return shape;
}
public List<Object> getData() {
return data;
}
public String getTypestr() {
return typestr;
}
public int getVersion() {
return version;
}
public CudfColumn getMask() {
return mask;
}
public void setMask(CudfColumn mask) {
this.mask = mask;
}
@Override
public String toJson() {
ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
try {
List<CudfColumn> objects = new ArrayList<>(1);
objects.add(this);
return mapper.writeValueAsString(objects);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,137 @@
/*
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import ai.rapids.cudf.Table;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* CudfColumnBatch wraps multiple CudfColumns to provide the cuda
* array interface json string for all columns.
*/
public class CudfColumnBatch extends ColumnBatch {
@JsonIgnore
private final Table featureTable;
@JsonIgnore
private final Table labelTable;
@JsonIgnore
private final Table weightTable;
@JsonIgnore
private final Table baseMarginTable;
@JsonIgnore
private final Table qidTable;
private List<CudfColumn> features;
private List<CudfColumn> label;
private List<CudfColumn> weight;
private List<CudfColumn> baseMargin;
private List<CudfColumn> qid;
public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
Table baseMarginTable, Table qidTable) {
this.featureTable = featureTable;
this.labelTable = labelTable;
this.weightTable = weightTable;
this.baseMarginTable = baseMarginTable;
this.qidTable = qidTable;
features = initializeCudfColumns(featureTable);
if (labelTable != null) {
assert labelTable.getNumberOfColumns() == 1;
label = initializeCudfColumns(labelTable);
}
if (weightTable != null) {
assert weightTable.getNumberOfColumns() == 1;
weight = initializeCudfColumns(weightTable);
}
if (baseMarginTable != null) {
baseMargin = initializeCudfColumns(baseMarginTable);
}
if (qidTable != null) {
qid = initializeCudfColumns(qidTable);
}
}
private List<CudfColumn> initializeCudfColumns(Table table) {
assert table != null && table.getNumberOfColumns() > 0;
return IntStream.range(0, table.getNumberOfColumns())
.mapToObj(table::getColumn)
.map(CudfColumn::from)
.collect(Collectors.toList());
}
public List<CudfColumn> getFeatures() {
return features;
}
public List<CudfColumn> getLabel() {
return label;
}
public List<CudfColumn> getWeight() {
return weight;
}
public List<CudfColumn> getBaseMargin() {
return baseMargin;
}
public List<CudfColumn> getQid() {
return qid;
}
public String toJson() {
ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
try {
return mapper.writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
@Override
public String toFeaturesJson() {
ObjectMapper mapper = new ObjectMapper();
try {
return mapper.writeValueAsString(features);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
@Override
public void close() {
if (featureTable != null) featureTable.close();
if (labelTable != null) labelTable.close();
if (weightTable != null) weightTable.close();
if (baseMarginTable != null) baseMarginTable.close();
if (qidTable != null) qidTable.close();
}
}

View File

@ -0,0 +1,90 @@
/*
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
String conf = getConfig(missing, maxBin, nthread);
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
iter, null, conf, out));
handle = out[0];
}
@Override
public void setLabel(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setLabel(float[] labels) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(float[] weights) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
}
private String getConfig(float missing, int maxBin, int nthread) {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
missing, maxBin, nthread);
}
}

View File

@ -17,14 +17,12 @@
package ml.dmlc.xgboost4j.scala.rapids.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
import org.apache.commons.logging.LogFactory
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
@ -325,7 +323,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
throw new RuntimeException("Something wrong for feature indices")
}
try {
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null)
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null, null)
val dm = new DMatrix(cudfColumnBatch, missing, 1)
if (dm == null) {
Iterator.empty
@ -586,7 +584,8 @@ object GpuPreXGBoost extends PreXGBoostProvider {
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava));
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava),
null);
}
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.util.HashMap;
@ -22,31 +22,21 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import ai.rapids.cudf.*;
import junit.framework.TestCase;
import org.junit.Test;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Schema;
import ai.rapids.cudf.Table;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.CSVOptions;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* Tests the BoosterTest trained by DMatrix
*
* @throws XGBoostError
*/
public class BoosterTest {
@Test
public void testBooster() throws XGBoostError {
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
String trainingDataPath = getClass().getClassLoader()
.getResource("veterans_lung_cancer.csv").getPath();
Schema schema = Schema.builder()
.column(DType.FLOAT32, "A")
.column(DType.FLOAT32, "B")
@ -78,7 +68,7 @@ public class BoosterTest {
put("num_round", round);
put("num_workers", 1);
put("tree_method", "hist");
put("device", "cuda");
put("device", "cuda");
put("max_bin", maxBin);
}
};
@ -95,7 +85,7 @@ public class BoosterTest {
try (Table y = new Table(labels);) {
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null, null);
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
//set watchList

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2022 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,24 +14,15 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
package ml.dmlc.xgboost4j.java;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import junit.framework.TestCase;
import com.google.common.primitives.Floats;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.Test;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
import junit.framework.TestCase;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
@ -43,24 +34,29 @@ public class DMatrixTest {
@Test
public void testCreateFromArrayInterfaceColumns() {
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
Integer[] groups = new Integer[]{1, 1, 7, 7, 19, 26};
int[] expectedGroup = new int[]{0, 2, 4, 5, 6};
Throwable ex = null;
try (
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
Table y = new Table.TestBuilder().column(labelFloats).build();
Table w = new Table.TestBuilder().column(labelFloats).build();
Table q = new Table.TestBuilder().column(groups).build();
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null);
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null, null);
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
CudfColumn qidColumn = CudfColumn.from(q.getColumn(0));
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
dMatrix.setLabel(labelColumn);
dMatrix.setWeight(weightColumn);
dMatrix.setBaseMargin(baseMarginColumn);
dMatrix.setQueryId(qidColumn);
String[] featureNames = new String[]{"f1"};
dMatrix.setFeatureNames(featureNames);
@ -76,10 +72,12 @@ public class DMatrixTest {
float[] label = dMatrix.getLabel();
float[] weight = dMatrix.getWeight();
float[] baseMargin = dMatrix.getBaseMargin();
int[] group = dMatrix.getGroup();
TestCase.assertTrue(Arrays.equals(anchor, label));
TestCase.assertTrue(Arrays.equals(anchor, weight));
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
TestCase.assertTrue(Arrays.equals(expectedGroup, group));
} catch (Throwable e) {
ex = e;
e.printStackTrace();
@ -93,10 +91,14 @@ public class DMatrixTest {
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
Integer[] groups1 = new Integer[]{1, 1, 7, 7, 19, 26};
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
Integer[] groups2 = new Integer[]{30, 30, 30, 40, 40};
int[] expectedGroup = new int[]{0, 2, 4, 5, 6, 9, 11};
try (
Table X_0 = new Table.TestBuilder()
@ -106,30 +108,47 @@ public class DMatrixTest {
Table y_0 = new Table.TestBuilder().column(label1).build();
Table w_0 = new Table.TestBuilder().column(weight1).build();
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
Table q_0 = new Table.TestBuilder().column(groups1).build();
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
Table y_1 = new Table.TestBuilder().column(label2).build();
Table w_1 = new Table.TestBuilder().column(weight2).build();
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
Table q_1 = new Table.TestBuilder().column(groups2).build();
List<ColumnBatch> tables = new LinkedList<>();
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1, q_1));
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
float[] anchorBaseMargin = convertFloatTofloat((Float[]) ArrayUtils.addAll(baseMargin1, baseMargin2));
float[] anchorLabel = convertFloatTofloat(label1, label2);
float[] anchorWeight = convertFloatTofloat(weight1, weight2);
float[] anchorBaseMargin = convertFloatTofloat(baseMargin1, baseMargin2);
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
TestCase.assertTrue(Arrays.equals(expectedGroup, dmat.getGroup()));
}
}
private float[] convertFloatTofloat(Float[] in) {
return Floats.toArray(Arrays.asList(in));
private float[] convertFloatTofloat(Float[]... datas) {
int totalLength = 0;
for (Float[] data : datas) {
totalLength += data.length;
}
float[] floatArray = new float[totalLength];
int index = 0;
for (Float[] data : datas) {
for (int i = 0; i < data.length; i++) {
floatArray[i + index] = data[i];
}
index += data.length;
}
return floatArray;
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,11 +16,11 @@
package ml.dmlc.xgboost4j.scala
import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.Table
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
import scala.collection.mutable.ArrayBuffer
class QuantileDMatrixSuite extends AnyFunSuite {
@ -29,10 +29,14 @@ class QuantileDMatrixSuite extends AnyFunSuite {
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)
val group1 = Array[java.lang.Integer](1, 1, 7, 7, 19, 26)
val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)
val group2 = Array[java.lang.Integer](30, 30, 30, 40, 40)
val expectedGroup = Array(0, 2, 4, 5, 6, 9, 11)
withResource(new Table.TestBuilder()
.column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
@ -41,22 +45,27 @@ class QuantileDMatrixSuite extends AnyFunSuite {
withResource(new Table.TestBuilder().column(label1: _*).build) { y_0 =>
withResource(new Table.TestBuilder().column(weight1: _*).build) { w_0 =>
withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 =>
withResource(new Table.TestBuilder()
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build)
{ X_1 =>
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
val batches = new ArrayBuffer[CudfColumnBatch]()
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
withResource(new Table.TestBuilder().column(group1: _*).build) { q_0 =>
withResource(new Table.TestBuilder()
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build) {
X_1 =>
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
withResource(new Table.TestBuilder().column(group2: _*).build) { q_2 =>
val batches = new ArrayBuffer[CudfColumnBatch]()
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0)
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1, q_2)
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
assert(dmatrix.getGroup().sameElements(expectedGroup))
}
}
}
}
}
}
}
}
@ -73,6 +82,4 @@ class QuantileDMatrixSuite extends AnyFunSuite {
r.close()
}
}
}

View File

@ -96,6 +96,8 @@
<argument>create_jni.py</argument>
<argument>--log-capi-invocation</argument>
<argument>${log.capi.invocation}</argument>
<argument>--use-cuda</argument>
<argument>${use.cuda}</argument>
</arguments>
<workingDirectory>${user.dir}</workingDirectory>
<skip>${skip.native.build}</skip>

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,24 +17,17 @@
package ml.dmlc.xgboost4j.java;
/**
* The abstracted XGBoost Column to get the cuda array interface which is used to
* set the information for DMatrix.
*
* This Column abstraction provides an array interface JSON string, which is
* used to reconstruct columnar data within the XGBoost library.
*/
public abstract class Column implements AutoCloseable {
/**
* Get the cuda array interface json string for the Column which can be representing
* weight, label, base margin column.
*
* This API will be called by
* {@link DMatrix#setLabel(Column)}
* {@link DMatrix#setWeight(Column)}
* {@link DMatrix#setBaseMargin(Column)}
* Return array interface json string for this Column
*/
public abstract String getArrayInterfaceJson();
public abstract String toJson();
@Override
public void close() throws Exception {}
public void close() throws Exception {
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021 by Contributors
Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,78 +16,12 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
* The abstracted XGBoost ColumnBatch to get array interface from columnar data format.
* For example, the cuDF dataframe which employs apache arrow specification.
* This class wraps multiple Column and provides the array interface json
* for all columns.
*/
public abstract class ColumnBatch implements AutoCloseable {
/**
* Get the cuda array interface json string for the whole ColumnBatch including
* the must-have feature, label columns and the optional weight, base margin columns.
*
* This function is be called by native code during iteration and can be made as private
* method. We keep it as public simply to silent the linter.
*/
public final String getArrayInterfaceJson() {
StringBuilder builder = new StringBuilder();
builder.append("{");
String featureStr = this.getFeatureArrayInterface();
if (featureStr == null || featureStr.isEmpty()) {
throw new RuntimeException("Feature array interface must not be empty");
} else {
builder.append("\"features_str\":" + featureStr);
}
String labelStr = this.getLabelsArrayInterface();
if (labelStr == null || labelStr.isEmpty()) {
throw new RuntimeException("Label array interface must not be empty");
} else {
builder.append(",\"label_str\":" + labelStr);
}
String weightStr = getWeightsArrayInterface();
if (weightStr != null && ! weightStr.isEmpty()) {
builder.append(",\"weight_str\":" + weightStr);
}
String baseMarginStr = getBaseMarginsArrayInterface();
if (baseMarginStr != null && ! baseMarginStr.isEmpty()) {
builder.append(",\"basemargin_str\":" + baseMarginStr);
}
builder.append("}");
return builder.toString();
}
/**
* Get the cuda array interface of the feature columns.
* The returned value must not be null or empty
*/
public abstract String getFeatureArrayInterface();
/**
* Get the cuda array interface of the label columns.
* The returned value must not be null or empty if we're creating
* {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)}
*/
public abstract String getLabelsArrayInterface();
/**
* Get the cuda array interface of the weight columns.
* The returned value can be null or empty
*/
public abstract String getWeightsArrayInterface();
/**
* Get the cuda array interface of the base margin columns.
* The returned value can be null or empty
*/
public abstract String getBaseMarginsArrayInterface();
@Override
public void close() throws Exception {}
public abstract class ColumnBatch extends Column {
/** Get features cuda array interface json string */
public abstract String toFeaturesJson();
}

View File

@ -195,7 +195,7 @@ public class DMatrix {
*/
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
long[] out = new long[1];
String json = columnBatch.getFeatureArrayInterface();
String json = columnBatch.toFeaturesJson();
if (json == null || json.isEmpty()) {
throw new XGBoostError("Expecting non-empty feature columns' array interface");
}
@ -228,7 +228,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setQueryId(Column column) throws XGBoostError {
setXGBDMatrixInfo("qid", column.getArrayInterfaceJson());
setXGBDMatrixInfo("qid", column.toJson());
}
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
@ -362,7 +362,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setLabel(Column column) throws XGBoostError {
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
setXGBDMatrixInfo("label", column.toJson());
}
/**
@ -393,7 +393,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setWeight(Column column) throws XGBoostError {
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
setXGBDMatrixInfo("weight", column.toJson());
}
/**
@ -421,7 +421,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setBaseMargin(Column column) throws XGBoostError {
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
setXGBDMatrixInfo("base_margin", column.toJson());
}
/**

View File

@ -104,7 +104,8 @@ void CopyInterface(std::vector<xgboost::ArrayInterface<1>> &interface_arr,
}
}
void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
template <typename T>
void CopyMetaInfo(Json *p_interface, dh::device_vector<T> *out, cudaStream_t stream) {
auto &j_interface = *p_interface;
CHECK_EQ(get<Array const>(j_interface).size(), 1);
auto object = get<Object>(get<Array>(j_interface)[0]);
@ -151,9 +152,11 @@ class DataIteratorProxy {
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
std::vector<std::unique_ptr<dh::device_vector<int>>> qids_;
std::vector<Json> label_interfaces_;
std::vector<Json> weight_interfaces_;
std::vector<Json> margin_interfaces_;
std::vector<Json> qid_interfaces_;
size_t it_{0};
size_t n_batches_{0};
@ -186,11 +189,11 @@ class DataIteratorProxy {
void StageMetaInfo(Json json_interface) {
CHECK(!IsA<Null>(json_interface));
auto json_map = get<Object const>(json_interface);
if (json_map.find("label_str") == json_map.cend()) {
if (json_map.find("label") == json_map.cend()) {
LOG(FATAL) << "Must have a label field.";
}
Json label = json_interface["label_str"];
Json label = json_interface["label"];
CHECK(!IsA<Null>(label));
labels_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
@ -200,8 +203,8 @@ class DataIteratorProxy {
Json::Dump(label, &str);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
if (json_map.find("weight_str") != json_map.cend()) {
Json weight = json_interface["weight_str"];
if (json_map.find("weight") != json_map.cend()) {
Json weight = json_interface["weight"];
CHECK(!IsA<Null>(weight));
weights_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
@ -211,8 +214,8 @@ class DataIteratorProxy {
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
if (json_map.find("basemargin_str") != json_map.cend()) {
Json basemargin = json_interface["basemargin_str"];
if (json_map.find("baseMargin") != json_map.cend()) {
Json basemargin = json_interface["baseMargin"];
base_margins_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
margin_interfaces_.emplace_back(basemargin);
@ -220,6 +223,16 @@ class DataIteratorProxy {
Json::Dump(basemargin, &str);
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
}
if (json_map.find("qid") != json_map.cend()) {
Json qid = json_interface["qid"];
qids_.emplace_back(new dh::device_vector<int>);
CopyMetaInfo(&qid, qids_.back().get(), copy_stream_);
qid_interfaces_.emplace_back(qid);
Json::Dump(qid, &str);
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
}
}
void CloseJvmBatch() {
@ -249,11 +262,11 @@ class DataIteratorProxy {
// batch should be ColumnBatch from jvm
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID(
batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_);
jmethodID toJson = CheckJvmCall(jenv_->GetMethodID(
batch_class, "toJson", "()Ljava/lang/String;"), jenv_);
auto jinterface =
static_cast<jstring>(jenv_->CallObjectMethod(batch, getArrayInterfaceJson));
static_cast<jstring>(jenv_->CallObjectMethod(batch, toJson));
CheckJvmCall(jinterface, jenv_);
char const *c_interface_str =
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
@ -281,7 +294,7 @@ class DataIteratorProxy {
CHECK(!IsA<Null>(json_interface));
StageMetaInfo(json_interface);
Json features = json_interface["features_str"];
Json features = json_interface["features"];
auto json_columns = get<Array const>(features);
std::vector<ArrayInterface<1>> interfaces;
@ -337,6 +350,12 @@ class DataIteratorProxy {
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
}
if (n_batches_ == this->qid_interfaces_.size()) {
auto const &qid = this->qid_interfaces_.at(it_);
Json::Dump(qid, &str);
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
}
// Data
auto const &json_interface = host_columns_.at(it_)->interfaces;