[jvm-packages] remove xgboost4j-gpu and rework cudf column (#10630)
This commit is contained in:
parent
fcae6301ec
commit
d5834b68c3
@ -33,7 +33,6 @@ def main(args):
|
||||
for artifact in [
|
||||
"xgboost-jvm",
|
||||
"xgboost4j",
|
||||
"xgboost4j-gpu",
|
||||
"xgboost4j-spark",
|
||||
"xgboost4j-spark-gpu",
|
||||
"xgboost4j-flink",
|
||||
|
||||
@ -2,11 +2,11 @@ find_package(JNI REQUIRED)
|
||||
|
||||
list(APPEND JVM_SOURCES
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp)
|
||||
|
||||
if(USE_CUDA)
|
||||
list(APPEND JVM_SOURCES
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu)
|
||||
endif()
|
||||
|
||||
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
|
||||
|
||||
@ -131,14 +131,6 @@ def native_build(args):
|
||||
run("cmake .. " + " ".join(args))
|
||||
run("cmake --build . --config Release" + maybe_parallel_build)
|
||||
|
||||
with cd("demo/CLI/regression"):
|
||||
run(f'"{sys.executable}" mapfeat.py')
|
||||
run(f'"{sys.executable}" mknfold.py machine.txt 1')
|
||||
|
||||
xgboost4j = "xgboost4j-gpu" if cli_args.use_cuda == "ON" else "xgboost4j"
|
||||
xgboost4j_spark = (
|
||||
"xgboost4j-spark-gpu" if cli_args.use_cuda == "ON" else "xgboost4j-spark"
|
||||
)
|
||||
|
||||
print("copying native library")
|
||||
library_name, os_folder = {
|
||||
@ -155,26 +147,34 @@ def native_build(args):
|
||||
"arm64": "aarch64", # on macOS & Windows ARM 64-bit
|
||||
"aarch64": "aarch64",
|
||||
}[platform.machine().lower()]
|
||||
output_folder = "{}/src/main/resources/lib/{}/{}".format(
|
||||
xgboost4j, os_folder, arch_folder
|
||||
output_folder = "xgboost4j/src/main/resources/lib/{}/{}".format(
|
||||
os_folder, arch_folder
|
||||
)
|
||||
maybe_makedirs(output_folder)
|
||||
cp("../lib/" + library_name, output_folder)
|
||||
|
||||
print("copying train/test files")
|
||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
|
||||
|
||||
# for xgboost4j
|
||||
maybe_makedirs("xgboost4j/src/test/resources")
|
||||
for file in glob.glob("../demo/data/agaricus.*"):
|
||||
cp(file, "xgboost4j/src/test/resources")
|
||||
|
||||
# for xgboost4j-spark
|
||||
maybe_makedirs("xgboost4j-spark/src/test/resources")
|
||||
with cd("../demo/CLI/regression"):
|
||||
run(f'"{sys.executable}" mapfeat.py')
|
||||
run(f'"{sys.executable}" mknfold.py machine.txt 1')
|
||||
|
||||
for file in glob.glob("../demo/CLI/regression/machine.txt.t*"):
|
||||
cp(file, "{}/src/test/resources".format(xgboost4j_spark))
|
||||
cp(file, "xgboost4j-spark/src/test/resources")
|
||||
for file in glob.glob("../demo/data/agaricus.*"):
|
||||
cp(file, "{}/src/test/resources".format(xgboost4j_spark))
|
||||
cp(file, "xgboost4j-spark/src/test/resources")
|
||||
|
||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j))
|
||||
for file in glob.glob("../demo/data/agaricus.*"):
|
||||
cp(file, "{}/src/test/resources".format(xgboost4j))
|
||||
# for xgboost4j-spark-gpu
|
||||
if cli_args.use_cuda == "ON":
|
||||
maybe_makedirs("xgboost4j-spark-gpu/src/test/resources")
|
||||
for file in glob.glob("../demo/data/veterans_lung_cancer.csv"):
|
||||
cp(file, "xgboost4j-spark-gpu/src/test/resources")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
<junit.version>4.13.2</junit.version>
|
||||
<spark.version>3.5.1</spark.version>
|
||||
<spark.version.gpu>3.5.1</spark.version.gpu>
|
||||
<fasterxml.jackson.version>2.17.2</fasterxml.jackson.version>
|
||||
<fasterxml.jackson.version>2.15.0</fasterxml.jackson.version>
|
||||
<scala.version>2.12.18</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<hadoop.version>3.4.0</hadoop.version>
|
||||
@ -105,7 +105,7 @@
|
||||
<use.cuda>ON</use.cuda>
|
||||
</properties>
|
||||
<modules>
|
||||
<module>xgboost4j-gpu</module>
|
||||
<module>xgboost4j</module>
|
||||
<module>xgboost4j-spark-gpu</module>
|
||||
</modules>
|
||||
</profile>
|
||||
@ -117,7 +117,6 @@
|
||||
<module>xgboost4j-example</module>
|
||||
<module>xgboost4j-spark</module>
|
||||
<module>xgboost4j-flink</module>
|
||||
<module>xgboost4j-gpu</module>
|
||||
<module>xgboost4j-spark-gpu</module>
|
||||
</modules>
|
||||
<build>
|
||||
@ -243,7 +242,6 @@
|
||||
<module>xgboost4j-example</module>
|
||||
<module>xgboost4j-spark</module>
|
||||
<module>xgboost4j-flink</module>
|
||||
<module>xgboost4j-gpu</module>
|
||||
<module>xgboost4j-spark-gpu</module>
|
||||
</modules>
|
||||
<build>
|
||||
|
||||
@ -1,140 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>2.2.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-gpu_2.12</artifactId>
|
||||
<name>xgboost4j-gpu</name>
|
||||
<version>2.2.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-compiler</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang.modules</groupId>
|
||||
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
|
||||
<version>${scala-collection-compat.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.rapids</groupId>
|
||||
<artifactId>cudf</artifactId>
|
||||
<version>${cudf.version}</version>
|
||||
<classifier>${cudf.classifier}</classifier>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-hdfs</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||
<version>${scalatest.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.14.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-javadoc-plugin</artifactId>
|
||||
<version>3.7.0</version>
|
||||
<configuration>
|
||||
<show>protected</show>
|
||||
<nohelp>true</nohelp>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipAssembly>false</skipAssembly>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<artifactId>exec-maven-plugin</artifactId>
|
||||
<groupId>org.codehaus.mojo</groupId>
|
||||
<version>3.3.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>native</id>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>exec</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<executable>python</executable>
|
||||
<arguments>
|
||||
<argument>create_jni.py</argument>
|
||||
<argument>--log-capi-invocation</argument>
|
||||
<argument>${log.capi.invocation}</argument>
|
||||
<argument>--use-cuda</argument>
|
||||
<argument>${use.cuda}</argument>
|
||||
</arguments>
|
||||
<workingDirectory>${user.dir}</workingDirectory>
|
||||
<skip>${skip.native.build}</skip>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<version>3.4.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>test-jar</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-resources-plugin</artifactId>
|
||||
<version>3.3.1</version>
|
||||
<configuration>
|
||||
<nonFilteredFileExtensions>
|
||||
<nonFilteredFileExtension>dll</nonFilteredFileExtension>
|
||||
<nonFilteredFileExtension>dylib</nonFilteredFileExtension>
|
||||
<nonFilteredFileExtension>so</nonFilteredFileExtension>
|
||||
</nonFilteredFileExtensions>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
||||
@ -1,110 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
||||
import ai.rapids.cudf.BufferType;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.DType;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Column;
|
||||
|
||||
/**
|
||||
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
|
||||
* It will be used to generate the cuda array interface.
|
||||
*/
|
||||
public class CudfColumn extends Column {
|
||||
|
||||
private final long dataPtr; // gpu data buffer address
|
||||
private final long shape; // row count
|
||||
private final long validPtr; // gpu valid buffer address
|
||||
private final int typeSize; // type size in bytes
|
||||
private final String typeStr; // follow array interface spec
|
||||
private final long nullCount; // null count
|
||||
|
||||
private String arrayInterface = null; // the cuda array interface
|
||||
|
||||
public static CudfColumn from(ColumnVector cv) {
|
||||
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
|
||||
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
|
||||
long validPtr = 0;
|
||||
if (validBuffer != null) {
|
||||
validPtr = validBuffer.getAddress();
|
||||
}
|
||||
DType dType = cv.getType();
|
||||
String typeStr = "";
|
||||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
||||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
||||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
||||
dType == DType.TIMESTAMP_SECONDS) {
|
||||
typeStr = "<f" + dType.getSizeInBytes();
|
||||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
||||
dType == DType.INT32 || dType == DType.INT64) {
|
||||
typeStr = "<i" + dType.getSizeInBytes();
|
||||
} else {
|
||||
// Unsupported type.
|
||||
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
||||
}
|
||||
|
||||
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
|
||||
dType.getSizeInBytes(), typeStr, cv.getNullCount());
|
||||
}
|
||||
|
||||
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
|
||||
long nullCount) {
|
||||
this.dataPtr = dataPtr;
|
||||
this.shape = shape;
|
||||
this.validPtr = validPtr;
|
||||
this.typeSize = typeSize;
|
||||
this.typeStr = typeStr;
|
||||
this.nullCount = nullCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getArrayInterfaceJson() {
|
||||
// There is no race-condition
|
||||
if (arrayInterface == null) {
|
||||
arrayInterface = CudfUtils.buildArrayInterface(this);
|
||||
}
|
||||
return arrayInterface;
|
||||
}
|
||||
|
||||
public long getDataPtr() {
|
||||
return dataPtr;
|
||||
}
|
||||
|
||||
public long getShape() {
|
||||
return shape;
|
||||
}
|
||||
|
||||
public long getValidPtr() {
|
||||
return validPtr;
|
||||
}
|
||||
|
||||
public int getTypeSize() {
|
||||
return typeSize;
|
||||
}
|
||||
|
||||
public String getTypeStr() {
|
||||
return typeStr;
|
||||
}
|
||||
|
||||
public long getNullCount() {
|
||||
return nullCount;
|
||||
}
|
||||
|
||||
}
|
||||
@ -1,88 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
|
||||
/**
|
||||
* Class to wrap CUDF Table to generate the cuda array interface.
|
||||
*/
|
||||
public class CudfColumnBatch extends ColumnBatch {
|
||||
private final Table feature;
|
||||
private final Table label;
|
||||
private final Table weight;
|
||||
private final Table baseMargin;
|
||||
|
||||
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
|
||||
this.feature = feature;
|
||||
this.label = labels;
|
||||
this.weight = weights;
|
||||
this.baseMargin = baseMargins;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFeatureArrayInterface() {
|
||||
return getArrayInterface(this.feature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getLabelsArrayInterface() {
|
||||
return getArrayInterface(this.label);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWeightsArrayInterface() {
|
||||
return getArrayInterface(this.weight);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getBaseMarginsArrayInterface() {
|
||||
return getArrayInterface(this.baseMargin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (feature != null) feature.close();
|
||||
if (label != null) label.close();
|
||||
if (weight != null) weight.close();
|
||||
if (baseMargin != null) baseMargin.close();
|
||||
}
|
||||
|
||||
private String getArrayInterface(Table table) {
|
||||
if (table == null || table.getNumberOfColumns() == 0) {
|
||||
return "";
|
||||
}
|
||||
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
|
||||
}
|
||||
|
||||
private CudfColumn[] getAsCudfColumn(Table table) {
|
||||
if (table == null || table.getNumberOfColumns() == 0) {
|
||||
// This will never happen.
|
||||
return new CudfColumn[]{};
|
||||
}
|
||||
|
||||
return IntStream.range(0, table.getNumberOfColumns())
|
||||
.mapToObj((i) -> table.getColumn(i))
|
||||
.map(CudfColumn::from)
|
||||
.toArray(CudfColumn[]::new);
|
||||
}
|
||||
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* Cudf utilities to build cuda array interface against {@link CudfColumn}
|
||||
*/
|
||||
class CudfUtils {
|
||||
|
||||
/**
|
||||
* Build the cuda array interface based on CudfColumn(s)
|
||||
* @param cudfColumns the CudfColumn(s) to be built
|
||||
* @return the json format of cuda array interface
|
||||
*/
|
||||
public static String buildArrayInterface(CudfColumn... cudfColumns) {
|
||||
return new Builder().add(cudfColumns).build();
|
||||
}
|
||||
|
||||
// Helper class to build array interface string
|
||||
private static class Builder {
|
||||
private ArrayList<String> colArrayInterfaces = new ArrayList<String>();
|
||||
|
||||
private Builder add(CudfColumn... columns) {
|
||||
if (columns == null || columns.length <= 0) {
|
||||
throw new IllegalArgumentException("At least one ColumnData is required.");
|
||||
}
|
||||
for (CudfColumn cd : columns) {
|
||||
colArrayInterfaces.add(buildColumnObject(cd));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
private String build() {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("[");
|
||||
for (int i = 0; i < colArrayInterfaces.size(); i++) {
|
||||
builder.append(colArrayInterfaces.get(i));
|
||||
if (i != colArrayInterfaces.size() - 1) {
|
||||
builder.append(",");
|
||||
}
|
||||
}
|
||||
builder.append("]");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/** build the whole column information including data and valid info */
|
||||
private String buildColumnObject(CudfColumn column) {
|
||||
if (column.getDataPtr() == 0) {
|
||||
throw new IllegalArgumentException("Empty column data is NOT accepted!");
|
||||
}
|
||||
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
|
||||
throw new IllegalArgumentException("Empty type string is NOT accepted!");
|
||||
}
|
||||
|
||||
StringBuilder builder = new StringBuilder();
|
||||
String colData = buildMetaObject(column.getDataPtr(), column.getShape(),
|
||||
column.getTypeStr());
|
||||
builder.append("{");
|
||||
builder.append(colData);
|
||||
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
|
||||
String validString = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
|
||||
builder.append(",\"mask\":");
|
||||
builder.append("{");
|
||||
builder.append(validString);
|
||||
builder.append("}");
|
||||
}
|
||||
builder.append("}");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/** build the base information of a column */
|
||||
private String buildMetaObject(long ptr, long shape, final String typeStr) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("\"shape\":[" + shape + "],");
|
||||
builder.append("\"data\":[" + ptr + "," + "false" + "],");
|
||||
builder.append("\"typestr\":\"" + typeStr + "\",");
|
||||
builder.append("\"version\":" + 1);
|
||||
return builder.toString();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -1 +0,0 @@
|
||||
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java
|
||||
@ -1 +0,0 @@
|
||||
../../../../xgboost4j/src/main/resources/xgboost4j-version.properties
|
||||
@ -1 +0,0 @@
|
||||
../../../xgboost4j/src/main/scala/
|
||||
@ -24,7 +24,7 @@
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-gpu_2.12</artifactId>
|
||||
<artifactId>xgboost4j_2.12</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
@ -51,5 +51,17 @@
|
||||
<version>${spark.rapids.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${fasterxml.jackson.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -0,0 +1,117 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.DType;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* CudfColumn is the CUDF column representing, providing the cuda array interface
|
||||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class CudfColumn extends Column {
|
||||
private List<Long> shape = new ArrayList<>(); // row count
|
||||
private List<Object> data = new ArrayList<>(); // gpu data buffer address
|
||||
private String typestr;
|
||||
private int version = 1;
|
||||
private CudfColumn mask = null;
|
||||
|
||||
public CudfColumn(long shape, long data, String typestr, int version) {
|
||||
this.shape.add(shape);
|
||||
this.data.add(data);
|
||||
this.data.add(false);
|
||||
this.typestr = typestr;
|
||||
this.version = version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create CudfColumn according to ColumnVector
|
||||
*/
|
||||
public static CudfColumn from(ColumnVector cv) {
|
||||
BaseDeviceMemoryBuffer dataBuffer = cv.getData();
|
||||
assert dataBuffer != null;
|
||||
|
||||
DType dType = cv.getType();
|
||||
String typeStr = "";
|
||||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
||||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
||||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
||||
dType == DType.TIMESTAMP_SECONDS) {
|
||||
typeStr = "<f" + dType.getSizeInBytes();
|
||||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
||||
dType == DType.INT32 || dType == DType.INT64) {
|
||||
typeStr = "<i" + dType.getSizeInBytes();
|
||||
} else {
|
||||
// Unsupported type.
|
||||
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
||||
}
|
||||
|
||||
CudfColumn data = new CudfColumn(cv.getRowCount(), dataBuffer.getAddress(), typeStr, 1);
|
||||
|
||||
BaseDeviceMemoryBuffer validBuffer = cv.getValid();
|
||||
if (validBuffer != null && cv.getNullCount() != 0) {
|
||||
CudfColumn mask = new CudfColumn(cv.getRowCount(), validBuffer.getAddress(), "<t1", 1);
|
||||
data.setMask(mask);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
public List<Long> getShape() {
|
||||
return shape;
|
||||
}
|
||||
|
||||
public List<Object> getData() {
|
||||
return data;
|
||||
}
|
||||
|
||||
public String getTypestr() {
|
||||
return typestr;
|
||||
}
|
||||
|
||||
public int getVersion() {
|
||||
return version;
|
||||
}
|
||||
|
||||
public CudfColumn getMask() {
|
||||
return mask;
|
||||
}
|
||||
|
||||
public void setMask(CudfColumn mask) {
|
||||
this.mask = mask;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toJson() {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||
try {
|
||||
List<CudfColumn> objects = new ArrayList<>(1);
|
||||
objects.add(this);
|
||||
return mapper.writeValueAsString(objects);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,137 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
/**
|
||||
* CudfColumnBatch wraps multiple CudfColumns to provide the cuda
|
||||
* array interface json string for all columns.
|
||||
*/
|
||||
public class CudfColumnBatch extends ColumnBatch {
|
||||
@JsonIgnore
|
||||
private final Table featureTable;
|
||||
@JsonIgnore
|
||||
private final Table labelTable;
|
||||
@JsonIgnore
|
||||
private final Table weightTable;
|
||||
@JsonIgnore
|
||||
private final Table baseMarginTable;
|
||||
@JsonIgnore
|
||||
private final Table qidTable;
|
||||
|
||||
private List<CudfColumn> features;
|
||||
private List<CudfColumn> label;
|
||||
private List<CudfColumn> weight;
|
||||
private List<CudfColumn> baseMargin;
|
||||
private List<CudfColumn> qid;
|
||||
|
||||
public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
|
||||
Table baseMarginTable, Table qidTable) {
|
||||
this.featureTable = featureTable;
|
||||
this.labelTable = labelTable;
|
||||
this.weightTable = weightTable;
|
||||
this.baseMarginTable = baseMarginTable;
|
||||
this.qidTable = qidTable;
|
||||
|
||||
features = initializeCudfColumns(featureTable);
|
||||
if (labelTable != null) {
|
||||
assert labelTable.getNumberOfColumns() == 1;
|
||||
label = initializeCudfColumns(labelTable);
|
||||
}
|
||||
|
||||
if (weightTable != null) {
|
||||
assert weightTable.getNumberOfColumns() == 1;
|
||||
weight = initializeCudfColumns(weightTable);
|
||||
}
|
||||
|
||||
if (baseMarginTable != null) {
|
||||
baseMargin = initializeCudfColumns(baseMarginTable);
|
||||
}
|
||||
|
||||
if (qidTable != null) {
|
||||
qid = initializeCudfColumns(qidTable);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private List<CudfColumn> initializeCudfColumns(Table table) {
|
||||
assert table != null && table.getNumberOfColumns() > 0;
|
||||
|
||||
return IntStream.range(0, table.getNumberOfColumns())
|
||||
.mapToObj(table::getColumn)
|
||||
.map(CudfColumn::from)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<CudfColumn> getFeatures() {
|
||||
return features;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getLabel() {
|
||||
return label;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getWeight() {
|
||||
return weight;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getBaseMargin() {
|
||||
return baseMargin;
|
||||
}
|
||||
|
||||
public List<CudfColumn> getQid() {
|
||||
return qid;
|
||||
}
|
||||
|
||||
public String toJson() {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||
try {
|
||||
return mapper.writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toFeaturesJson() {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
try {
|
||||
return mapper.writeValueAsString(features);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (featureTable != null) featureTable.close();
|
||||
if (labelTable != null) labelTable.close();
|
||||
if (weightTable != null) weightTable.close();
|
||||
if (baseMarginTable != null) baseMarginTable.close();
|
||||
if (qidTable != null) qidTable.close();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,90 @@
|
||||
/*
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* QuantileDMatrix will only be used to train
|
||||
*/
|
||||
public class QuantileDMatrix extends DMatrix {
|
||||
/**
|
||||
* Create QuantileDMatrix from iterator based on the cuda array interface
|
||||
*
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public QuantileDMatrix(
|
||||
Iterator<ColumnBatch> iter,
|
||||
float missing,
|
||||
int maxBin,
|
||||
int nthread) throws XGBoostError {
|
||||
super(0);
|
||||
long[] out = new long[1];
|
||||
String conf = getConfig(missing, maxBin, nthread);
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
|
||||
iter, null, conf, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
|
||||
}
|
||||
|
||||
private String getConfig(float missing, int maxBin, int nthread) {
|
||||
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
|
||||
missing, maxBin, nthread);
|
||||
}
|
||||
}
|
||||
@ -17,14 +17,12 @@
|
||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
||||
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.rdd.RDD
|
||||
@ -325,7 +323,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
throw new RuntimeException("Something wrong for feature indices")
|
||||
}
|
||||
try {
|
||||
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null)
|
||||
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null, null)
|
||||
val dm = new DMatrix(cudfColumnBatch, missing, 1)
|
||||
if (dm == null) {
|
||||
Iterator.empty
|
||||
@ -586,7 +584,8 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava));
|
||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava),
|
||||
null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,7 +14,7 @@
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.HashMap;
|
||||
@ -22,31 +22,21 @@ import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import ai.rapids.cudf.*;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import ai.rapids.cudf.DType;
|
||||
import ai.rapids.cudf.Schema;
|
||||
import ai.rapids.cudf.Table;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.CSVOptions;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* Tests the BoosterTest trained by DMatrix
|
||||
*
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public class BoosterTest {
|
||||
|
||||
@Test
|
||||
public void testBooster() throws XGBoostError {
|
||||
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
|
||||
String trainingDataPath = getClass().getClassLoader()
|
||||
.getResource("veterans_lung_cancer.csv").getPath();
|
||||
Schema schema = Schema.builder()
|
||||
.column(DType.FLOAT32, "A")
|
||||
.column(DType.FLOAT32, "B")
|
||||
@ -95,7 +85,7 @@ public class BoosterTest {
|
||||
|
||||
try (Table y = new Table(labels);) {
|
||||
|
||||
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
|
||||
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null, null);
|
||||
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
|
||||
|
||||
//set watchList
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,24 +14,15 @@
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
|
||||
import com.google.common.primitives.Floats;
|
||||
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.junit.Test;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
@ -43,24 +34,29 @@ public class DMatrixTest {
|
||||
@Test
|
||||
public void testCreateFromArrayInterfaceColumns() {
|
||||
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
|
||||
Integer[] groups = new Integer[]{1, 1, 7, 7, 19, 26};
|
||||
int[] expectedGroup = new int[]{0, 2, 4, 5, 6};
|
||||
|
||||
Throwable ex = null;
|
||||
try (
|
||||
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
|
||||
Table y = new Table.TestBuilder().column(labelFloats).build();
|
||||
Table w = new Table.TestBuilder().column(labelFloats).build();
|
||||
Table q = new Table.TestBuilder().column(groups).build();
|
||||
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
|
||||
|
||||
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null);
|
||||
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null, null);
|
||||
|
||||
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
|
||||
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
|
||||
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
|
||||
CudfColumn qidColumn = CudfColumn.from(q.getColumn(0));
|
||||
|
||||
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
|
||||
dMatrix.setLabel(labelColumn);
|
||||
dMatrix.setWeight(weightColumn);
|
||||
dMatrix.setBaseMargin(baseMarginColumn);
|
||||
dMatrix.setQueryId(qidColumn);
|
||||
|
||||
String[] featureNames = new String[]{"f1"};
|
||||
dMatrix.setFeatureNames(featureNames);
|
||||
@ -76,10 +72,12 @@ public class DMatrixTest {
|
||||
float[] label = dMatrix.getLabel();
|
||||
float[] weight = dMatrix.getWeight();
|
||||
float[] baseMargin = dMatrix.getBaseMargin();
|
||||
int[] group = dMatrix.getGroup();
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(anchor, label));
|
||||
TestCase.assertTrue(Arrays.equals(anchor, weight));
|
||||
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
|
||||
TestCase.assertTrue(Arrays.equals(expectedGroup, group));
|
||||
} catch (Throwable e) {
|
||||
ex = e;
|
||||
e.printStackTrace();
|
||||
@ -93,10 +91,14 @@ public class DMatrixTest {
|
||||
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
|
||||
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
|
||||
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
|
||||
Integer[] groups1 = new Integer[]{1, 1, 7, 7, 19, 26};
|
||||
|
||||
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
|
||||
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
|
||||
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
|
||||
Integer[] groups2 = new Integer[]{30, 30, 30, 40, 40};
|
||||
|
||||
int[] expectedGroup = new int[]{0, 2, 4, 5, 6, 9, 11};
|
||||
|
||||
try (
|
||||
Table X_0 = new Table.TestBuilder()
|
||||
@ -106,30 +108,47 @@ public class DMatrixTest {
|
||||
Table y_0 = new Table.TestBuilder().column(label1).build();
|
||||
Table w_0 = new Table.TestBuilder().column(weight1).build();
|
||||
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
|
||||
Table q_0 = new Table.TestBuilder().column(groups1).build();
|
||||
|
||||
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
|
||||
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
|
||||
Table y_1 = new Table.TestBuilder().column(label2).build();
|
||||
Table w_1 = new Table.TestBuilder().column(weight2).build();
|
||||
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
|
||||
Table q_1 = new Table.TestBuilder().column(groups2).build();
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
|
||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
|
||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
|
||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0));
|
||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1, q_1));
|
||||
|
||||
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
||||
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
|
||||
|
||||
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
|
||||
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
|
||||
float[] anchorBaseMargin = convertFloatTofloat((Float[]) ArrayUtils.addAll(baseMargin1, baseMargin2));
|
||||
float[] anchorLabel = convertFloatTofloat(label1, label2);
|
||||
float[] anchorWeight = convertFloatTofloat(weight1, weight2);
|
||||
float[] anchorBaseMargin = convertFloatTofloat(baseMargin1, baseMargin2);
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
|
||||
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
|
||||
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
|
||||
TestCase.assertTrue(Arrays.equals(expectedGroup, dmat.getGroup()));
|
||||
}
|
||||
}
|
||||
|
||||
private float[] convertFloatTofloat(Float[] in) {
|
||||
return Floats.toArray(Arrays.asList(in));
|
||||
private float[] convertFloatTofloat(Float[]... datas) {
|
||||
int totalLength = 0;
|
||||
for (Float[] data : datas) {
|
||||
totalLength += data.length;
|
||||
}
|
||||
float[] floatArray = new float[totalLength];
|
||||
int index = 0;
|
||||
for (Float[] data : datas) {
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
floatArray[i + index] = data[i];
|
||||
}
|
||||
index += data.length;
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -16,11 +16,11 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import ai.rapids.cudf.Table
|
||||
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
class QuantileDMatrixSuite extends AnyFunSuite {
|
||||
|
||||
@ -29,10 +29,14 @@ class QuantileDMatrixSuite extends AnyFunSuite {
|
||||
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
||||
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
||||
val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)
|
||||
val group1 = Array[java.lang.Integer](1, 1, 7, 7, 19, 26)
|
||||
|
||||
val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
|
||||
val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
|
||||
val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)
|
||||
val group2 = Array[java.lang.Integer](30, 30, 30, 40, 40)
|
||||
|
||||
val expectedGroup = Array(0, 2, 4, 5, 6, 9, 11)
|
||||
|
||||
withResource(new Table.TestBuilder()
|
||||
.column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
|
||||
@ -41,20 +45,25 @@ class QuantileDMatrixSuite extends AnyFunSuite {
|
||||
withResource(new Table.TestBuilder().column(label1: _*).build) { y_0 =>
|
||||
withResource(new Table.TestBuilder().column(weight1: _*).build) { w_0 =>
|
||||
withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 =>
|
||||
withResource(new Table.TestBuilder().column(group1: _*).build) { q_0 =>
|
||||
withResource(new Table.TestBuilder()
|
||||
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
|
||||
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build)
|
||||
{ X_1 =>
|
||||
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build) {
|
||||
X_1 =>
|
||||
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
|
||||
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
|
||||
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
|
||||
withResource(new Table.TestBuilder().column(group2: _*).build) { q_2 =>
|
||||
val batches = new ArrayBuffer[CudfColumnBatch]()
|
||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
|
||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
|
||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0)
|
||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1, q_2)
|
||||
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
||||
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
||||
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
||||
assert(dmatrix.getGroup().sameElements(expectedGroup))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -73,6 +82,4 @@ class QuantileDMatrixSuite extends AnyFunSuite {
|
||||
r.close()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -96,6 +96,8 @@
|
||||
<argument>create_jni.py</argument>
|
||||
<argument>--log-capi-invocation</argument>
|
||||
<argument>${log.capi.invocation}</argument>
|
||||
<argument>--use-cuda</argument>
|
||||
<argument>${use.cuda}</argument>
|
||||
</arguments>
|
||||
<workingDirectory>${user.dir}</workingDirectory>
|
||||
<skip>${skip.native.build}</skip>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -17,24 +17,17 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost Column to get the cuda array interface which is used to
|
||||
* set the information for DMatrix.
|
||||
*
|
||||
* This Column abstraction provides an array interface JSON string, which is
|
||||
* used to reconstruct columnar data within the XGBoost library.
|
||||
*/
|
||||
public abstract class Column implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Get the cuda array interface json string for the Column which can be representing
|
||||
* weight, label, base margin column.
|
||||
*
|
||||
* This API will be called by
|
||||
* {@link DMatrix#setLabel(Column)}
|
||||
* {@link DMatrix#setWeight(Column)}
|
||||
* {@link DMatrix#setBaseMargin(Column)}
|
||||
* Return array interface json string for this Column
|
||||
*/
|
||||
public abstract String getArrayInterfaceJson();
|
||||
public abstract String toJson();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
|
||||
public void close() throws Exception {
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -16,78 +16,12 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost ColumnBatch to get array interface from columnar data format.
|
||||
* For example, the cuDF dataframe which employs apache arrow specification.
|
||||
* This class wraps multiple Column and provides the array interface json
|
||||
* for all columns.
|
||||
*/
|
||||
public abstract class ColumnBatch implements AutoCloseable {
|
||||
/**
|
||||
* Get the cuda array interface json string for the whole ColumnBatch including
|
||||
* the must-have feature, label columns and the optional weight, base margin columns.
|
||||
*
|
||||
* This function is be called by native code during iteration and can be made as private
|
||||
* method. We keep it as public simply to silent the linter.
|
||||
*/
|
||||
public final String getArrayInterfaceJson() {
|
||||
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("{");
|
||||
String featureStr = this.getFeatureArrayInterface();
|
||||
if (featureStr == null || featureStr.isEmpty()) {
|
||||
throw new RuntimeException("Feature array interface must not be empty");
|
||||
} else {
|
||||
builder.append("\"features_str\":" + featureStr);
|
||||
}
|
||||
|
||||
String labelStr = this.getLabelsArrayInterface();
|
||||
if (labelStr == null || labelStr.isEmpty()) {
|
||||
throw new RuntimeException("Label array interface must not be empty");
|
||||
} else {
|
||||
builder.append(",\"label_str\":" + labelStr);
|
||||
}
|
||||
|
||||
String weightStr = getWeightsArrayInterface();
|
||||
if (weightStr != null && ! weightStr.isEmpty()) {
|
||||
builder.append(",\"weight_str\":" + weightStr);
|
||||
}
|
||||
|
||||
String baseMarginStr = getBaseMarginsArrayInterface();
|
||||
if (baseMarginStr != null && ! baseMarginStr.isEmpty()) {
|
||||
builder.append(",\"basemargin_str\":" + baseMarginStr);
|
||||
}
|
||||
|
||||
builder.append("}");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the feature columns.
|
||||
* The returned value must not be null or empty
|
||||
*/
|
||||
public abstract String getFeatureArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the label columns.
|
||||
* The returned value must not be null or empty if we're creating
|
||||
* {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)}
|
||||
*/
|
||||
public abstract String getLabelsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the weight columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getWeightsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the base margin columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getBaseMarginsArrayInterface();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
public abstract class ColumnBatch extends Column {
|
||||
|
||||
/** Get features cuda array interface json string */
|
||||
public abstract String toFeaturesJson();
|
||||
}
|
||||
|
||||
@ -195,7 +195,7 @@ public class DMatrix {
|
||||
*/
|
||||
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
String json = columnBatch.getFeatureArrayInterface();
|
||||
String json = columnBatch.toFeaturesJson();
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Expecting non-empty feature columns' array interface");
|
||||
}
|
||||
@ -228,7 +228,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setQueryId(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("qid", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("qid", column.toJson());
|
||||
}
|
||||
|
||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||
@ -362,7 +362,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("label", column.toJson());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -393,7 +393,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("weight", column.toJson());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -421,7 +421,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("base_margin", column.toJson());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -104,7 +104,8 @@ void CopyInterface(std::vector<xgboost::ArrayInterface<1>> &interface_arr,
|
||||
}
|
||||
}
|
||||
|
||||
void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
|
||||
template <typename T>
|
||||
void CopyMetaInfo(Json *p_interface, dh::device_vector<T> *out, cudaStream_t stream) {
|
||||
auto &j_interface = *p_interface;
|
||||
CHECK_EQ(get<Array const>(j_interface).size(), 1);
|
||||
auto object = get<Object>(get<Array>(j_interface)[0]);
|
||||
@ -151,9 +152,11 @@ class DataIteratorProxy {
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<int>>> qids_;
|
||||
std::vector<Json> label_interfaces_;
|
||||
std::vector<Json> weight_interfaces_;
|
||||
std::vector<Json> margin_interfaces_;
|
||||
std::vector<Json> qid_interfaces_;
|
||||
|
||||
size_t it_{0};
|
||||
size_t n_batches_{0};
|
||||
@ -186,11 +189,11 @@ class DataIteratorProxy {
|
||||
void StageMetaInfo(Json json_interface) {
|
||||
CHECK(!IsA<Null>(json_interface));
|
||||
auto json_map = get<Object const>(json_interface);
|
||||
if (json_map.find("label_str") == json_map.cend()) {
|
||||
if (json_map.find("label") == json_map.cend()) {
|
||||
LOG(FATAL) << "Must have a label field.";
|
||||
}
|
||||
|
||||
Json label = json_interface["label_str"];
|
||||
Json label = json_interface["label"];
|
||||
CHECK(!IsA<Null>(label));
|
||||
labels_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
|
||||
@ -200,8 +203,8 @@ class DataIteratorProxy {
|
||||
Json::Dump(label, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||
|
||||
if (json_map.find("weight_str") != json_map.cend()) {
|
||||
Json weight = json_interface["weight_str"];
|
||||
if (json_map.find("weight") != json_map.cend()) {
|
||||
Json weight = json_interface["weight"];
|
||||
CHECK(!IsA<Null>(weight));
|
||||
weights_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
|
||||
@ -211,8 +214,8 @@ class DataIteratorProxy {
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||
}
|
||||
|
||||
if (json_map.find("basemargin_str") != json_map.cend()) {
|
||||
Json basemargin = json_interface["basemargin_str"];
|
||||
if (json_map.find("baseMargin") != json_map.cend()) {
|
||||
Json basemargin = json_interface["baseMargin"];
|
||||
base_margins_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
|
||||
margin_interfaces_.emplace_back(basemargin);
|
||||
@ -220,6 +223,16 @@ class DataIteratorProxy {
|
||||
Json::Dump(basemargin, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||
}
|
||||
|
||||
if (json_map.find("qid") != json_map.cend()) {
|
||||
Json qid = json_interface["qid"];
|
||||
qids_.emplace_back(new dh::device_vector<int>);
|
||||
CopyMetaInfo(&qid, qids_.back().get(), copy_stream_);
|
||||
qid_interfaces_.emplace_back(qid);
|
||||
|
||||
Json::Dump(qid, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void CloseJvmBatch() {
|
||||
@ -249,11 +262,11 @@ class DataIteratorProxy {
|
||||
// batch should be ColumnBatch from jvm
|
||||
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
|
||||
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
|
||||
jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID(
|
||||
batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_);
|
||||
jmethodID toJson = CheckJvmCall(jenv_->GetMethodID(
|
||||
batch_class, "toJson", "()Ljava/lang/String;"), jenv_);
|
||||
|
||||
auto jinterface =
|
||||
static_cast<jstring>(jenv_->CallObjectMethod(batch, getArrayInterfaceJson));
|
||||
static_cast<jstring>(jenv_->CallObjectMethod(batch, toJson));
|
||||
CheckJvmCall(jinterface, jenv_);
|
||||
char const *c_interface_str =
|
||||
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
|
||||
@ -281,7 +294,7 @@ class DataIteratorProxy {
|
||||
CHECK(!IsA<Null>(json_interface));
|
||||
StageMetaInfo(json_interface);
|
||||
|
||||
Json features = json_interface["features_str"];
|
||||
Json features = json_interface["features"];
|
||||
auto json_columns = get<Array const>(features);
|
||||
std::vector<ArrayInterface<1>> interfaces;
|
||||
|
||||
@ -337,6 +350,12 @@ class DataIteratorProxy {
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||
}
|
||||
|
||||
if (n_batches_ == this->qid_interfaces_.size()) {
|
||||
auto const &qid = this->qid_interfaces_.at(it_);
|
||||
Json::Dump(qid, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||
}
|
||||
|
||||
// Data
|
||||
auto const &json_interface = host_columns_.at(it_)->interfaces;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user