[jvm-packages] remove xgboost4j-gpu and rework cudf column (#10630)
This commit is contained in:
parent
fcae6301ec
commit
d5834b68c3
@ -33,7 +33,6 @@ def main(args):
|
|||||||
for artifact in [
|
for artifact in [
|
||||||
"xgboost-jvm",
|
"xgboost-jvm",
|
||||||
"xgboost4j",
|
"xgboost4j",
|
||||||
"xgboost4j-gpu",
|
|
||||||
"xgboost4j-spark",
|
"xgboost4j-spark",
|
||||||
"xgboost4j-spark-gpu",
|
"xgboost4j-spark-gpu",
|
||||||
"xgboost4j-flink",
|
"xgboost4j-flink",
|
||||||
|
|||||||
@ -2,11 +2,11 @@ find_package(JNI REQUIRED)
|
|||||||
|
|
||||||
list(APPEND JVM_SOURCES
|
list(APPEND JVM_SOURCES
|
||||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
|
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
|
||||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
|
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp)
|
||||||
|
|
||||||
if(USE_CUDA)
|
if(USE_CUDA)
|
||||||
list(APPEND JVM_SOURCES
|
list(APPEND JVM_SOURCES
|
||||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
|
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
|
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
|
||||||
|
|||||||
@ -131,14 +131,6 @@ def native_build(args):
|
|||||||
run("cmake .. " + " ".join(args))
|
run("cmake .. " + " ".join(args))
|
||||||
run("cmake --build . --config Release" + maybe_parallel_build)
|
run("cmake --build . --config Release" + maybe_parallel_build)
|
||||||
|
|
||||||
with cd("demo/CLI/regression"):
|
|
||||||
run(f'"{sys.executable}" mapfeat.py')
|
|
||||||
run(f'"{sys.executable}" mknfold.py machine.txt 1')
|
|
||||||
|
|
||||||
xgboost4j = "xgboost4j-gpu" if cli_args.use_cuda == "ON" else "xgboost4j"
|
|
||||||
xgboost4j_spark = (
|
|
||||||
"xgboost4j-spark-gpu" if cli_args.use_cuda == "ON" else "xgboost4j-spark"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("copying native library")
|
print("copying native library")
|
||||||
library_name, os_folder = {
|
library_name, os_folder = {
|
||||||
@ -155,26 +147,34 @@ def native_build(args):
|
|||||||
"arm64": "aarch64", # on macOS & Windows ARM 64-bit
|
"arm64": "aarch64", # on macOS & Windows ARM 64-bit
|
||||||
"aarch64": "aarch64",
|
"aarch64": "aarch64",
|
||||||
}[platform.machine().lower()]
|
}[platform.machine().lower()]
|
||||||
output_folder = "{}/src/main/resources/lib/{}/{}".format(
|
output_folder = "xgboost4j/src/main/resources/lib/{}/{}".format(
|
||||||
xgboost4j, os_folder, arch_folder
|
os_folder, arch_folder
|
||||||
)
|
)
|
||||||
maybe_makedirs(output_folder)
|
maybe_makedirs(output_folder)
|
||||||
cp("../lib/" + library_name, output_folder)
|
cp("../lib/" + library_name, output_folder)
|
||||||
|
|
||||||
print("copying train/test files")
|
print("copying train/test files")
|
||||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
|
|
||||||
|
# for xgboost4j
|
||||||
|
maybe_makedirs("xgboost4j/src/test/resources")
|
||||||
|
for file in glob.glob("../demo/data/agaricus.*"):
|
||||||
|
cp(file, "xgboost4j/src/test/resources")
|
||||||
|
|
||||||
|
# for xgboost4j-spark
|
||||||
|
maybe_makedirs("xgboost4j-spark/src/test/resources")
|
||||||
with cd("../demo/CLI/regression"):
|
with cd("../demo/CLI/regression"):
|
||||||
run(f'"{sys.executable}" mapfeat.py')
|
run(f'"{sys.executable}" mapfeat.py')
|
||||||
run(f'"{sys.executable}" mknfold.py machine.txt 1')
|
run(f'"{sys.executable}" mknfold.py machine.txt 1')
|
||||||
|
|
||||||
for file in glob.glob("../demo/CLI/regression/machine.txt.t*"):
|
for file in glob.glob("../demo/CLI/regression/machine.txt.t*"):
|
||||||
cp(file, "{}/src/test/resources".format(xgboost4j_spark))
|
cp(file, "xgboost4j-spark/src/test/resources")
|
||||||
for file in glob.glob("../demo/data/agaricus.*"):
|
for file in glob.glob("../demo/data/agaricus.*"):
|
||||||
cp(file, "{}/src/test/resources".format(xgboost4j_spark))
|
cp(file, "xgboost4j-spark/src/test/resources")
|
||||||
|
|
||||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j))
|
# for xgboost4j-spark-gpu
|
||||||
for file in glob.glob("../demo/data/agaricus.*"):
|
if cli_args.use_cuda == "ON":
|
||||||
cp(file, "{}/src/test/resources".format(xgboost4j))
|
maybe_makedirs("xgboost4j-spark-gpu/src/test/resources")
|
||||||
|
for file in glob.glob("../demo/data/veterans_lung_cancer.csv"):
|
||||||
|
cp(file, "xgboost4j-spark-gpu/src/test/resources")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -37,7 +37,7 @@
|
|||||||
<junit.version>4.13.2</junit.version>
|
<junit.version>4.13.2</junit.version>
|
||||||
<spark.version>3.5.1</spark.version>
|
<spark.version>3.5.1</spark.version>
|
||||||
<spark.version.gpu>3.5.1</spark.version.gpu>
|
<spark.version.gpu>3.5.1</spark.version.gpu>
|
||||||
<fasterxml.jackson.version>2.17.2</fasterxml.jackson.version>
|
<fasterxml.jackson.version>2.15.0</fasterxml.jackson.version>
|
||||||
<scala.version>2.12.18</scala.version>
|
<scala.version>2.12.18</scala.version>
|
||||||
<scala.binary.version>2.12</scala.binary.version>
|
<scala.binary.version>2.12</scala.binary.version>
|
||||||
<hadoop.version>3.4.0</hadoop.version>
|
<hadoop.version>3.4.0</hadoop.version>
|
||||||
@ -105,7 +105,7 @@
|
|||||||
<use.cuda>ON</use.cuda>
|
<use.cuda>ON</use.cuda>
|
||||||
</properties>
|
</properties>
|
||||||
<modules>
|
<modules>
|
||||||
<module>xgboost4j-gpu</module>
|
<module>xgboost4j</module>
|
||||||
<module>xgboost4j-spark-gpu</module>
|
<module>xgboost4j-spark-gpu</module>
|
||||||
</modules>
|
</modules>
|
||||||
</profile>
|
</profile>
|
||||||
@ -117,7 +117,6 @@
|
|||||||
<module>xgboost4j-example</module>
|
<module>xgboost4j-example</module>
|
||||||
<module>xgboost4j-spark</module>
|
<module>xgboost4j-spark</module>
|
||||||
<module>xgboost4j-flink</module>
|
<module>xgboost4j-flink</module>
|
||||||
<module>xgboost4j-gpu</module>
|
|
||||||
<module>xgboost4j-spark-gpu</module>
|
<module>xgboost4j-spark-gpu</module>
|
||||||
</modules>
|
</modules>
|
||||||
<build>
|
<build>
|
||||||
@ -243,7 +242,6 @@
|
|||||||
<module>xgboost4j-example</module>
|
<module>xgboost4j-example</module>
|
||||||
<module>xgboost4j-spark</module>
|
<module>xgboost4j-spark</module>
|
||||||
<module>xgboost4j-flink</module>
|
<module>xgboost4j-flink</module>
|
||||||
<module>xgboost4j-gpu</module>
|
|
||||||
<module>xgboost4j-spark-gpu</module>
|
<module>xgboost4j-spark-gpu</module>
|
||||||
</modules>
|
</modules>
|
||||||
<build>
|
<build>
|
||||||
|
|||||||
@ -1,140 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
|
||||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
<parent>
|
|
||||||
<groupId>ml.dmlc</groupId>
|
|
||||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
|
||||||
<version>2.2.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<artifactId>xgboost4j-gpu_2.12</artifactId>
|
|
||||||
<name>xgboost4j-gpu</name>
|
|
||||||
<version>2.2.0-SNAPSHOT</version>
|
|
||||||
<packaging>jar</packaging>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.scala-lang</groupId>
|
|
||||||
<artifactId>scala-compiler</artifactId>
|
|
||||||
<version>${scala.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.scala-lang</groupId>
|
|
||||||
<artifactId>scala-library</artifactId>
|
|
||||||
<version>${scala.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.scala-lang.modules</groupId>
|
|
||||||
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
|
|
||||||
<version>${scala-collection-compat.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>ai.rapids</groupId>
|
|
||||||
<artifactId>cudf</artifactId>
|
|
||||||
<version>${cudf.version}</version>
|
|
||||||
<classifier>${cudf.classifier}</classifier>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.hadoop</groupId>
|
|
||||||
<artifactId>hadoop-hdfs</artifactId>
|
|
||||||
<version>${hadoop.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.hadoop</groupId>
|
|
||||||
<artifactId>hadoop-common</artifactId>
|
|
||||||
<version>${hadoop.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>junit</groupId>
|
|
||||||
<artifactId>junit</artifactId>
|
|
||||||
<version>${junit.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.scalatest</groupId>
|
|
||||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
|
||||||
<version>${scalatest.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.commons</groupId>
|
|
||||||
<artifactId>commons-lang3</artifactId>
|
|
||||||
<version>3.14.0</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-javadoc-plugin</artifactId>
|
|
||||||
<version>3.7.0</version>
|
|
||||||
<configuration>
|
|
||||||
<show>protected</show>
|
|
||||||
<nohelp>true</nohelp>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-assembly-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<skipAssembly>false</skipAssembly>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>exec-maven-plugin</artifactId>
|
|
||||||
<groupId>org.codehaus.mojo</groupId>
|
|
||||||
<version>3.3.0</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>native</id>
|
|
||||||
<phase>generate-sources</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>exec</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<executable>python</executable>
|
|
||||||
<arguments>
|
|
||||||
<argument>create_jni.py</argument>
|
|
||||||
<argument>--log-capi-invocation</argument>
|
|
||||||
<argument>${log.capi.invocation}</argument>
|
|
||||||
<argument>--use-cuda</argument>
|
|
||||||
<argument>${use.cuda}</argument>
|
|
||||||
</arguments>
|
|
||||||
<workingDirectory>${user.dir}</workingDirectory>
|
|
||||||
<skip>${skip.native.build}</skip>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-jar-plugin</artifactId>
|
|
||||||
<version>3.4.1</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<goals>
|
|
||||||
<goal>test-jar</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-resources-plugin</artifactId>
|
|
||||||
<version>3.3.1</version>
|
|
||||||
<configuration>
|
|
||||||
<nonFilteredFileExtensions>
|
|
||||||
<nonFilteredFileExtension>dll</nonFilteredFileExtension>
|
|
||||||
<nonFilteredFileExtension>dylib</nonFilteredFileExtension>
|
|
||||||
<nonFilteredFileExtension>so</nonFilteredFileExtension>
|
|
||||||
</nonFilteredFileExtensions>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</project>
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.gpu.java;
|
|
||||||
|
|
||||||
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
|
||||||
import ai.rapids.cudf.BufferType;
|
|
||||||
import ai.rapids.cudf.ColumnVector;
|
|
||||||
import ai.rapids.cudf.DType;
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Column;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
|
|
||||||
* It will be used to generate the cuda array interface.
|
|
||||||
*/
|
|
||||||
public class CudfColumn extends Column {
|
|
||||||
|
|
||||||
private final long dataPtr; // gpu data buffer address
|
|
||||||
private final long shape; // row count
|
|
||||||
private final long validPtr; // gpu valid buffer address
|
|
||||||
private final int typeSize; // type size in bytes
|
|
||||||
private final String typeStr; // follow array interface spec
|
|
||||||
private final long nullCount; // null count
|
|
||||||
|
|
||||||
private String arrayInterface = null; // the cuda array interface
|
|
||||||
|
|
||||||
public static CudfColumn from(ColumnVector cv) {
|
|
||||||
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
|
|
||||||
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
|
|
||||||
long validPtr = 0;
|
|
||||||
if (validBuffer != null) {
|
|
||||||
validPtr = validBuffer.getAddress();
|
|
||||||
}
|
|
||||||
DType dType = cv.getType();
|
|
||||||
String typeStr = "";
|
|
||||||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
|
||||||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
|
||||||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
|
||||||
dType == DType.TIMESTAMP_SECONDS) {
|
|
||||||
typeStr = "<f" + dType.getSizeInBytes();
|
|
||||||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
|
||||||
dType == DType.INT32 || dType == DType.INT64) {
|
|
||||||
typeStr = "<i" + dType.getSizeInBytes();
|
|
||||||
} else {
|
|
||||||
// Unsupported type.
|
|
||||||
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
|
|
||||||
dType.getSizeInBytes(), typeStr, cv.getNullCount());
|
|
||||||
}
|
|
||||||
|
|
||||||
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
|
|
||||||
long nullCount) {
|
|
||||||
this.dataPtr = dataPtr;
|
|
||||||
this.shape = shape;
|
|
||||||
this.validPtr = validPtr;
|
|
||||||
this.typeSize = typeSize;
|
|
||||||
this.typeStr = typeStr;
|
|
||||||
this.nullCount = nullCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getArrayInterfaceJson() {
|
|
||||||
// There is no race-condition
|
|
||||||
if (arrayInterface == null) {
|
|
||||||
arrayInterface = CudfUtils.buildArrayInterface(this);
|
|
||||||
}
|
|
||||||
return arrayInterface;
|
|
||||||
}
|
|
||||||
|
|
||||||
public long getDataPtr() {
|
|
||||||
return dataPtr;
|
|
||||||
}
|
|
||||||
|
|
||||||
public long getShape() {
|
|
||||||
return shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
public long getValidPtr() {
|
|
||||||
return validPtr;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getTypeSize() {
|
|
||||||
return typeSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getTypeStr() {
|
|
||||||
return typeStr;
|
|
||||||
}
|
|
||||||
|
|
||||||
public long getNullCount() {
|
|
||||||
return nullCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@ -1,88 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.gpu.java;
|
|
||||||
|
|
||||||
import java.util.stream.IntStream;
|
|
||||||
|
|
||||||
import ai.rapids.cudf.Table;
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Class to wrap CUDF Table to generate the cuda array interface.
|
|
||||||
*/
|
|
||||||
public class CudfColumnBatch extends ColumnBatch {
|
|
||||||
private final Table feature;
|
|
||||||
private final Table label;
|
|
||||||
private final Table weight;
|
|
||||||
private final Table baseMargin;
|
|
||||||
|
|
||||||
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
|
|
||||||
this.feature = feature;
|
|
||||||
this.label = labels;
|
|
||||||
this.weight = weights;
|
|
||||||
this.baseMargin = baseMargins;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getFeatureArrayInterface() {
|
|
||||||
return getArrayInterface(this.feature);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getLabelsArrayInterface() {
|
|
||||||
return getArrayInterface(this.label);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getWeightsArrayInterface() {
|
|
||||||
return getArrayInterface(this.weight);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getBaseMarginsArrayInterface() {
|
|
||||||
return getArrayInterface(this.baseMargin);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() {
|
|
||||||
if (feature != null) feature.close();
|
|
||||||
if (label != null) label.close();
|
|
||||||
if (weight != null) weight.close();
|
|
||||||
if (baseMargin != null) baseMargin.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getArrayInterface(Table table) {
|
|
||||||
if (table == null || table.getNumberOfColumns() == 0) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
|
|
||||||
}
|
|
||||||
|
|
||||||
private CudfColumn[] getAsCudfColumn(Table table) {
|
|
||||||
if (table == null || table.getNumberOfColumns() == 0) {
|
|
||||||
// This will never happen.
|
|
||||||
return new CudfColumn[]{};
|
|
||||||
}
|
|
||||||
|
|
||||||
return IntStream.range(0, table.getNumberOfColumns())
|
|
||||||
.mapToObj((i) -> table.getColumn(i))
|
|
||||||
.map(CudfColumn::from)
|
|
||||||
.toArray(CudfColumn[]::new);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@ -1,98 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2021-2022 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.gpu.java;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cudf utilities to build cuda array interface against {@link CudfColumn}
|
|
||||||
*/
|
|
||||||
class CudfUtils {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build the cuda array interface based on CudfColumn(s)
|
|
||||||
* @param cudfColumns the CudfColumn(s) to be built
|
|
||||||
* @return the json format of cuda array interface
|
|
||||||
*/
|
|
||||||
public static String buildArrayInterface(CudfColumn... cudfColumns) {
|
|
||||||
return new Builder().add(cudfColumns).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper class to build array interface string
|
|
||||||
private static class Builder {
|
|
||||||
private ArrayList<String> colArrayInterfaces = new ArrayList<String>();
|
|
||||||
|
|
||||||
private Builder add(CudfColumn... columns) {
|
|
||||||
if (columns == null || columns.length <= 0) {
|
|
||||||
throw new IllegalArgumentException("At least one ColumnData is required.");
|
|
||||||
}
|
|
||||||
for (CudfColumn cd : columns) {
|
|
||||||
colArrayInterfaces.add(buildColumnObject(cd));
|
|
||||||
}
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
private String build() {
|
|
||||||
StringBuilder builder = new StringBuilder();
|
|
||||||
builder.append("[");
|
|
||||||
for (int i = 0; i < colArrayInterfaces.size(); i++) {
|
|
||||||
builder.append(colArrayInterfaces.get(i));
|
|
||||||
if (i != colArrayInterfaces.size() - 1) {
|
|
||||||
builder.append(",");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
builder.append("]");
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** build the whole column information including data and valid info */
|
|
||||||
private String buildColumnObject(CudfColumn column) {
|
|
||||||
if (column.getDataPtr() == 0) {
|
|
||||||
throw new IllegalArgumentException("Empty column data is NOT accepted!");
|
|
||||||
}
|
|
||||||
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
|
|
||||||
throw new IllegalArgumentException("Empty type string is NOT accepted!");
|
|
||||||
}
|
|
||||||
|
|
||||||
StringBuilder builder = new StringBuilder();
|
|
||||||
String colData = buildMetaObject(column.getDataPtr(), column.getShape(),
|
|
||||||
column.getTypeStr());
|
|
||||||
builder.append("{");
|
|
||||||
builder.append(colData);
|
|
||||||
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
|
|
||||||
String validString = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
|
|
||||||
builder.append(",\"mask\":");
|
|
||||||
builder.append("{");
|
|
||||||
builder.append(validString);
|
|
||||||
builder.append("}");
|
|
||||||
}
|
|
||||||
builder.append("}");
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** build the base information of a column */
|
|
||||||
private String buildMetaObject(long ptr, long shape, final String typeStr) {
|
|
||||||
StringBuilder builder = new StringBuilder();
|
|
||||||
builder.append("\"shape\":[" + shape + "],");
|
|
||||||
builder.append("\"data\":[" + ptr + "," + "false" + "],");
|
|
||||||
builder.append("\"typestr\":\"" + typeStr + "\",");
|
|
||||||
builder.append("\"version\":" + 1);
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@ -1 +0,0 @@
|
|||||||
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java
|
|
||||||
@ -1 +0,0 @@
|
|||||||
../../../../xgboost4j/src/main/resources/xgboost4j-version.properties
|
|
||||||
@ -1 +0,0 @@
|
|||||||
../../../xgboost4j/src/main/scala/
|
|
||||||
@ -24,7 +24,7 @@
|
|||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j-gpu_2.12</artifactId>
|
<artifactId>xgboost4j_2.12</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
@ -51,5 +51,17 @@
|
|||||||
<version>${spark.rapids.version}</version>
|
<version>${spark.rapids.version}</version>
|
||||||
<scope>provided</scope>
|
<scope>provided</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-databind</artifactId>
|
||||||
|
<version>${fasterxml.jackson.version}</version>
|
||||||
|
<scope>provided</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<version>${junit.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@ -0,0 +1,117 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
||||||
|
import ai.rapids.cudf.ColumnVector;
|
||||||
|
import ai.rapids.cudf.DType;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CudfColumn is the CUDF column representing, providing the cuda array interface
|
||||||
|
*/
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
public class CudfColumn extends Column {
|
||||||
|
private List<Long> shape = new ArrayList<>(); // row count
|
||||||
|
private List<Object> data = new ArrayList<>(); // gpu data buffer address
|
||||||
|
private String typestr;
|
||||||
|
private int version = 1;
|
||||||
|
private CudfColumn mask = null;
|
||||||
|
|
||||||
|
public CudfColumn(long shape, long data, String typestr, int version) {
|
||||||
|
this.shape.add(shape);
|
||||||
|
this.data.add(data);
|
||||||
|
this.data.add(false);
|
||||||
|
this.typestr = typestr;
|
||||||
|
this.version = version;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create CudfColumn according to ColumnVector
|
||||||
|
*/
|
||||||
|
public static CudfColumn from(ColumnVector cv) {
|
||||||
|
BaseDeviceMemoryBuffer dataBuffer = cv.getData();
|
||||||
|
assert dataBuffer != null;
|
||||||
|
|
||||||
|
DType dType = cv.getType();
|
||||||
|
String typeStr = "";
|
||||||
|
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
||||||
|
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
||||||
|
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
||||||
|
dType == DType.TIMESTAMP_SECONDS) {
|
||||||
|
typeStr = "<f" + dType.getSizeInBytes();
|
||||||
|
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
||||||
|
dType == DType.INT32 || dType == DType.INT64) {
|
||||||
|
typeStr = "<i" + dType.getSizeInBytes();
|
||||||
|
} else {
|
||||||
|
// Unsupported type.
|
||||||
|
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
||||||
|
}
|
||||||
|
|
||||||
|
CudfColumn data = new CudfColumn(cv.getRowCount(), dataBuffer.getAddress(), typeStr, 1);
|
||||||
|
|
||||||
|
BaseDeviceMemoryBuffer validBuffer = cv.getValid();
|
||||||
|
if (validBuffer != null && cv.getNullCount() != 0) {
|
||||||
|
CudfColumn mask = new CudfColumn(cv.getRowCount(), validBuffer.getAddress(), "<t1", 1);
|
||||||
|
data.setMask(mask);
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Long> getShape() {
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Object> getData() {
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTypestr() {
|
||||||
|
return typestr;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getVersion() {
|
||||||
|
return version;
|
||||||
|
}
|
||||||
|
|
||||||
|
public CudfColumn getMask() {
|
||||||
|
return mask;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMask(CudfColumn mask) {
|
||||||
|
this.mask = mask;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toJson() {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||||
|
try {
|
||||||
|
List<CudfColumn> objects = new ArrayList<>(1);
|
||||||
|
objects.add(this);
|
||||||
|
return mapper.writeValueAsString(objects);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,137 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
import ai.rapids.cudf.Table;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CudfColumnBatch wraps multiple CudfColumns to provide the cuda
|
||||||
|
* array interface json string for all columns.
|
||||||
|
*/
|
||||||
|
public class CudfColumnBatch extends ColumnBatch {
|
||||||
|
@JsonIgnore
|
||||||
|
private final Table featureTable;
|
||||||
|
@JsonIgnore
|
||||||
|
private final Table labelTable;
|
||||||
|
@JsonIgnore
|
||||||
|
private final Table weightTable;
|
||||||
|
@JsonIgnore
|
||||||
|
private final Table baseMarginTable;
|
||||||
|
@JsonIgnore
|
||||||
|
private final Table qidTable;
|
||||||
|
|
||||||
|
private List<CudfColumn> features;
|
||||||
|
private List<CudfColumn> label;
|
||||||
|
private List<CudfColumn> weight;
|
||||||
|
private List<CudfColumn> baseMargin;
|
||||||
|
private List<CudfColumn> qid;
|
||||||
|
|
||||||
|
public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
|
||||||
|
Table baseMarginTable, Table qidTable) {
|
||||||
|
this.featureTable = featureTable;
|
||||||
|
this.labelTable = labelTable;
|
||||||
|
this.weightTable = weightTable;
|
||||||
|
this.baseMarginTable = baseMarginTable;
|
||||||
|
this.qidTable = qidTable;
|
||||||
|
|
||||||
|
features = initializeCudfColumns(featureTable);
|
||||||
|
if (labelTable != null) {
|
||||||
|
assert labelTable.getNumberOfColumns() == 1;
|
||||||
|
label = initializeCudfColumns(labelTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weightTable != null) {
|
||||||
|
assert weightTable.getNumberOfColumns() == 1;
|
||||||
|
weight = initializeCudfColumns(weightTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (baseMarginTable != null) {
|
||||||
|
baseMargin = initializeCudfColumns(baseMarginTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (qidTable != null) {
|
||||||
|
qid = initializeCudfColumns(qidTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<CudfColumn> initializeCudfColumns(Table table) {
|
||||||
|
assert table != null && table.getNumberOfColumns() > 0;
|
||||||
|
|
||||||
|
return IntStream.range(0, table.getNumberOfColumns())
|
||||||
|
.mapToObj(table::getColumn)
|
||||||
|
.map(CudfColumn::from)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<CudfColumn> getFeatures() {
|
||||||
|
return features;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<CudfColumn> getLabel() {
|
||||||
|
return label;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<CudfColumn> getWeight() {
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<CudfColumn> getBaseMargin() {
|
||||||
|
return baseMargin;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<CudfColumn> getQid() {
|
||||||
|
return qid;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String toJson() {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||||
|
try {
|
||||||
|
return mapper.writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toFeaturesJson() {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
try {
|
||||||
|
return mapper.writeValueAsString(features);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
if (featureTable != null) featureTable.close();
|
||||||
|
if (labelTable != null) labelTable.close();
|
||||||
|
if (weightTable != null) weightTable.close();
|
||||||
|
if (baseMarginTable != null) baseMarginTable.close();
|
||||||
|
if (qidTable != null) qidTable.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,90 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.util.Iterator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* QuantileDMatrix will only be used to train
|
||||||
|
*/
|
||||||
|
public class QuantileDMatrix extends DMatrix {
|
||||||
|
/**
|
||||||
|
* Create QuantileDMatrix from iterator based on the cuda array interface
|
||||||
|
*
|
||||||
|
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||||
|
* @param missing the missing value
|
||||||
|
* @param maxBin the max bin
|
||||||
|
* @param nthread the parallelism
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public QuantileDMatrix(
|
||||||
|
Iterator<ColumnBatch> iter,
|
||||||
|
float missing,
|
||||||
|
int maxBin,
|
||||||
|
int nthread) throws XGBoostError {
|
||||||
|
super(0);
|
||||||
|
long[] out = new long[1];
|
||||||
|
String conf = getConfig(missing, maxBin, nthread);
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
|
||||||
|
iter, null, conf, out));
|
||||||
|
handle = out[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setLabel(Column column) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(Column column) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setBaseMargin(Column column) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setLabel(float[] labels) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setWeight(float[] weights) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setGroup(int[] group) throws XGBoostError {
|
||||||
|
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getConfig(float missing, int maxBin, int nthread) {
|
||||||
|
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
|
||||||
|
missing, maxBin, nthread);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,14 +17,12 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.rapids.spark
|
package ml.dmlc.xgboost4j.scala.rapids.spark
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
|
||||||
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
|
||||||
|
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||||
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
|
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
|
||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
|
|
||||||
import org.apache.spark.{SparkContext, TaskContext}
|
import org.apache.spark.{SparkContext, TaskContext}
|
||||||
import org.apache.spark.ml.{Estimator, Model}
|
import org.apache.spark.ml.{Estimator, Model}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
@ -325,7 +323,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
throw new RuntimeException("Something wrong for feature indices")
|
throw new RuntimeException("Something wrong for feature indices")
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null)
|
val cudfColumnBatch = new CudfColumnBatch(feaTable, null, null, null, null)
|
||||||
val dm = new DMatrix(cudfColumnBatch, missing, 1)
|
val dm = new DMatrix(cudfColumnBatch, missing, 1)
|
||||||
if (dm == null) {
|
if (dm == null) {
|
||||||
Iterator.empty
|
Iterator.empty
|
||||||
@ -586,7 +584,8 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
|||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
|
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(indices.featureIds).asJava),
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
|
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(Seq(indices.labelId)).asJava),
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
|
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(weights).asJava),
|
||||||
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava));
|
gpuColumnBatch.slice(GpuUtils.seqIntToSeqInteger(margins).asJava),
|
||||||
|
null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -14,7 +14,7 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.gpu.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@ -22,31 +22,21 @@ import java.util.LinkedList;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
import ai.rapids.cudf.*;
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import ai.rapids.cudf.DType;
|
|
||||||
import ai.rapids.cudf.Schema;
|
|
||||||
import ai.rapids.cudf.Table;
|
|
||||||
import ai.rapids.cudf.ColumnVector;
|
|
||||||
import ai.rapids.cudf.CSVOptions;
|
|
||||||
import ml.dmlc.xgboost4j.java.Booster;
|
|
||||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
|
||||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
|
||||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
|
||||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests the BoosterTest trained by DMatrix
|
* Tests the BoosterTest trained by DMatrix
|
||||||
|
*
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
public class BoosterTest {
|
public class BoosterTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBooster() throws XGBoostError {
|
public void testBooster() throws XGBoostError {
|
||||||
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
|
String trainingDataPath = getClass().getClassLoader()
|
||||||
|
.getResource("veterans_lung_cancer.csv").getPath();
|
||||||
Schema schema = Schema.builder()
|
Schema schema = Schema.builder()
|
||||||
.column(DType.FLOAT32, "A")
|
.column(DType.FLOAT32, "A")
|
||||||
.column(DType.FLOAT32, "B")
|
.column(DType.FLOAT32, "B")
|
||||||
@ -95,7 +85,7 @@ public class BoosterTest {
|
|||||||
|
|
||||||
try (Table y = new Table(labels);) {
|
try (Table y = new Table(labels);) {
|
||||||
|
|
||||||
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
|
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null, null);
|
||||||
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
|
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
|
||||||
|
|
||||||
//set watchList
|
//set watchList
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021-2022 by Contributors
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -14,24 +14,15 @@
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.gpu.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
|
||||||
|
|
||||||
import com.google.common.primitives.Floats;
|
|
||||||
|
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import ai.rapids.cudf.Table;
|
import ai.rapids.cudf.Table;
|
||||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
import junit.framework.TestCase;
|
||||||
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
|
import org.junit.Test;
|
||||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
||||||
@ -43,24 +34,29 @@ public class DMatrixTest {
|
|||||||
@Test
|
@Test
|
||||||
public void testCreateFromArrayInterfaceColumns() {
|
public void testCreateFromArrayInterfaceColumns() {
|
||||||
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
|
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
|
||||||
|
Integer[] groups = new Integer[]{1, 1, 7, 7, 19, 26};
|
||||||
|
int[] expectedGroup = new int[]{0, 2, 4, 5, 6};
|
||||||
|
|
||||||
Throwable ex = null;
|
Throwable ex = null;
|
||||||
try (
|
try (
|
||||||
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
|
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
|
||||||
Table y = new Table.TestBuilder().column(labelFloats).build();
|
Table y = new Table.TestBuilder().column(labelFloats).build();
|
||||||
Table w = new Table.TestBuilder().column(labelFloats).build();
|
Table w = new Table.TestBuilder().column(labelFloats).build();
|
||||||
|
Table q = new Table.TestBuilder().column(groups).build();
|
||||||
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
|
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
|
||||||
|
|
||||||
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null);
|
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null, null);
|
||||||
|
|
||||||
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
|
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
|
||||||
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
|
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
|
||||||
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
|
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
|
||||||
|
CudfColumn qidColumn = CudfColumn.from(q.getColumn(0));
|
||||||
|
|
||||||
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
|
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
|
||||||
dMatrix.setLabel(labelColumn);
|
dMatrix.setLabel(labelColumn);
|
||||||
dMatrix.setWeight(weightColumn);
|
dMatrix.setWeight(weightColumn);
|
||||||
dMatrix.setBaseMargin(baseMarginColumn);
|
dMatrix.setBaseMargin(baseMarginColumn);
|
||||||
|
dMatrix.setQueryId(qidColumn);
|
||||||
|
|
||||||
String[] featureNames = new String[]{"f1"};
|
String[] featureNames = new String[]{"f1"};
|
||||||
dMatrix.setFeatureNames(featureNames);
|
dMatrix.setFeatureNames(featureNames);
|
||||||
@ -76,10 +72,12 @@ public class DMatrixTest {
|
|||||||
float[] label = dMatrix.getLabel();
|
float[] label = dMatrix.getLabel();
|
||||||
float[] weight = dMatrix.getWeight();
|
float[] weight = dMatrix.getWeight();
|
||||||
float[] baseMargin = dMatrix.getBaseMargin();
|
float[] baseMargin = dMatrix.getBaseMargin();
|
||||||
|
int[] group = dMatrix.getGroup();
|
||||||
|
|
||||||
TestCase.assertTrue(Arrays.equals(anchor, label));
|
TestCase.assertTrue(Arrays.equals(anchor, label));
|
||||||
TestCase.assertTrue(Arrays.equals(anchor, weight));
|
TestCase.assertTrue(Arrays.equals(anchor, weight));
|
||||||
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
|
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
|
||||||
|
TestCase.assertTrue(Arrays.equals(expectedGroup, group));
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
ex = e;
|
ex = e;
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
@ -93,10 +91,14 @@ public class DMatrixTest {
|
|||||||
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
|
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
|
||||||
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
|
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
|
||||||
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
|
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
|
||||||
|
Integer[] groups1 = new Integer[]{1, 1, 7, 7, 19, 26};
|
||||||
|
|
||||||
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
|
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
|
||||||
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
|
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
|
||||||
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
|
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
|
||||||
|
Integer[] groups2 = new Integer[]{30, 30, 30, 40, 40};
|
||||||
|
|
||||||
|
int[] expectedGroup = new int[]{0, 2, 4, 5, 6, 9, 11};
|
||||||
|
|
||||||
try (
|
try (
|
||||||
Table X_0 = new Table.TestBuilder()
|
Table X_0 = new Table.TestBuilder()
|
||||||
@ -106,30 +108,47 @@ public class DMatrixTest {
|
|||||||
Table y_0 = new Table.TestBuilder().column(label1).build();
|
Table y_0 = new Table.TestBuilder().column(label1).build();
|
||||||
Table w_0 = new Table.TestBuilder().column(weight1).build();
|
Table w_0 = new Table.TestBuilder().column(weight1).build();
|
||||||
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
|
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
|
||||||
|
Table q_0 = new Table.TestBuilder().column(groups1).build();
|
||||||
|
|
||||||
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
|
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
|
||||||
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
|
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
|
||||||
Table y_1 = new Table.TestBuilder().column(label2).build();
|
Table y_1 = new Table.TestBuilder().column(label2).build();
|
||||||
Table w_1 = new Table.TestBuilder().column(weight2).build();
|
Table w_1 = new Table.TestBuilder().column(weight2).build();
|
||||||
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
|
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
|
||||||
|
Table q_1 = new Table.TestBuilder().column(groups2).build();
|
||||||
|
|
||||||
List<ColumnBatch> tables = new LinkedList<>();
|
List<ColumnBatch> tables = new LinkedList<>();
|
||||||
|
|
||||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
|
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0));
|
||||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
|
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1, q_1));
|
||||||
|
|
||||||
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 256, 1);
|
||||||
|
|
||||||
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
|
float[] anchorLabel = convertFloatTofloat(label1, label2);
|
||||||
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
|
float[] anchorWeight = convertFloatTofloat(weight1, weight2);
|
||||||
float[] anchorBaseMargin = convertFloatTofloat((Float[]) ArrayUtils.addAll(baseMargin1, baseMargin2));
|
float[] anchorBaseMargin = convertFloatTofloat(baseMargin1, baseMargin2);
|
||||||
|
|
||||||
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
|
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
|
||||||
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
|
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
|
||||||
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
|
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
|
||||||
|
TestCase.assertTrue(Arrays.equals(expectedGroup, dmat.getGroup()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private float[] convertFloatTofloat(Float[] in) {
|
private float[] convertFloatTofloat(Float[]... datas) {
|
||||||
return Floats.toArray(Arrays.asList(in));
|
int totalLength = 0;
|
||||||
|
for (Float[] data : datas) {
|
||||||
|
totalLength += data.length;
|
||||||
}
|
}
|
||||||
|
float[] floatArray = new float[totalLength];
|
||||||
|
int index = 0;
|
||||||
|
for (Float[] data : datas) {
|
||||||
|
for (int i = 0; i < data.length; i++) {
|
||||||
|
floatArray[i + index] = data[i];
|
||||||
|
}
|
||||||
|
index += data.length;
|
||||||
|
}
|
||||||
|
return floatArray;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -16,11 +16,11 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import scala.collection.mutable.ArrayBuffer
|
|
||||||
|
|
||||||
import ai.rapids.cudf.Table
|
import ai.rapids.cudf.Table
|
||||||
|
import ml.dmlc.xgboost4j.java.CudfColumnBatch
|
||||||
import org.scalatest.funsuite.AnyFunSuite
|
import org.scalatest.funsuite.AnyFunSuite
|
||||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
class QuantileDMatrixSuite extends AnyFunSuite {
|
class QuantileDMatrixSuite extends AnyFunSuite {
|
||||||
|
|
||||||
@ -29,10 +29,14 @@ class QuantileDMatrixSuite extends AnyFunSuite {
|
|||||||
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
||||||
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
||||||
val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)
|
val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)
|
||||||
|
val group1 = Array[java.lang.Integer](1, 1, 7, 7, 19, 26)
|
||||||
|
|
||||||
val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
|
val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
|
||||||
val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
|
val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
|
||||||
val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)
|
val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)
|
||||||
|
val group2 = Array[java.lang.Integer](30, 30, 30, 40, 40)
|
||||||
|
|
||||||
|
val expectedGroup = Array(0, 2, 4, 5, 6, 9, 11)
|
||||||
|
|
||||||
withResource(new Table.TestBuilder()
|
withResource(new Table.TestBuilder()
|
||||||
.column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
|
.column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
|
||||||
@ -41,20 +45,25 @@ class QuantileDMatrixSuite extends AnyFunSuite {
|
|||||||
withResource(new Table.TestBuilder().column(label1: _*).build) { y_0 =>
|
withResource(new Table.TestBuilder().column(label1: _*).build) { y_0 =>
|
||||||
withResource(new Table.TestBuilder().column(weight1: _*).build) { w_0 =>
|
withResource(new Table.TestBuilder().column(weight1: _*).build) { w_0 =>
|
||||||
withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 =>
|
withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 =>
|
||||||
|
withResource(new Table.TestBuilder().column(group1: _*).build) { q_0 =>
|
||||||
withResource(new Table.TestBuilder()
|
withResource(new Table.TestBuilder()
|
||||||
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
|
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
|
||||||
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build)
|
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build) {
|
||||||
{ X_1 =>
|
X_1 =>
|
||||||
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
|
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
|
||||||
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
|
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
|
||||||
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
|
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
|
||||||
|
withResource(new Table.TestBuilder().column(group2: _*).build) { q_2 =>
|
||||||
val batches = new ArrayBuffer[CudfColumnBatch]()
|
val batches = new ArrayBuffer[CudfColumnBatch]()
|
||||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
|
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0, q_0)
|
||||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
|
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1, q_2)
|
||||||
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||||
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
||||||
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
||||||
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
||||||
|
assert(dmatrix.getGroup().sameElements(expectedGroup))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -73,6 +82,4 @@ class QuantileDMatrixSuite extends AnyFunSuite {
|
|||||||
r.close()
|
r.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,6 +96,8 @@
|
|||||||
<argument>create_jni.py</argument>
|
<argument>create_jni.py</argument>
|
||||||
<argument>--log-capi-invocation</argument>
|
<argument>--log-capi-invocation</argument>
|
||||||
<argument>${log.capi.invocation}</argument>
|
<argument>${log.capi.invocation}</argument>
|
||||||
|
<argument>--use-cuda</argument>
|
||||||
|
<argument>${use.cuda}</argument>
|
||||||
</arguments>
|
</arguments>
|
||||||
<workingDirectory>${user.dir}</workingDirectory>
|
<workingDirectory>${user.dir}</workingDirectory>
|
||||||
<skip>${skip.native.build}</skip>
|
<skip>${skip.native.build}</skip>
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -17,24 +17,17 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The abstracted XGBoost Column to get the cuda array interface which is used to
|
* This Column abstraction provides an array interface JSON string, which is
|
||||||
* set the information for DMatrix.
|
* used to reconstruct columnar data within the XGBoost library.
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
public abstract class Column implements AutoCloseable {
|
public abstract class Column implements AutoCloseable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the cuda array interface json string for the Column which can be representing
|
* Return array interface json string for this Column
|
||||||
* weight, label, base margin column.
|
|
||||||
*
|
|
||||||
* This API will be called by
|
|
||||||
* {@link DMatrix#setLabel(Column)}
|
|
||||||
* {@link DMatrix#setWeight(Column)}
|
|
||||||
* {@link DMatrix#setBaseMargin(Column)}
|
|
||||||
*/
|
*/
|
||||||
public abstract String getArrayInterfaceJson();
|
public abstract String toJson();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws Exception {}
|
public void close() throws Exception {
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2021 by Contributors
|
Copyright (c) 2021-2024 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -16,78 +16,12 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.util.Iterator;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The abstracted XGBoost ColumnBatch to get array interface from columnar data format.
|
* This class wraps multiple Column and provides the array interface json
|
||||||
* For example, the cuDF dataframe which employs apache arrow specification.
|
* for all columns.
|
||||||
*/
|
*/
|
||||||
public abstract class ColumnBatch implements AutoCloseable {
|
public abstract class ColumnBatch extends Column {
|
||||||
/**
|
|
||||||
* Get the cuda array interface json string for the whole ColumnBatch including
|
|
||||||
* the must-have feature, label columns and the optional weight, base margin columns.
|
|
||||||
*
|
|
||||||
* This function is be called by native code during iteration and can be made as private
|
|
||||||
* method. We keep it as public simply to silent the linter.
|
|
||||||
*/
|
|
||||||
public final String getArrayInterfaceJson() {
|
|
||||||
|
|
||||||
StringBuilder builder = new StringBuilder();
|
|
||||||
builder.append("{");
|
|
||||||
String featureStr = this.getFeatureArrayInterface();
|
|
||||||
if (featureStr == null || featureStr.isEmpty()) {
|
|
||||||
throw new RuntimeException("Feature array interface must not be empty");
|
|
||||||
} else {
|
|
||||||
builder.append("\"features_str\":" + featureStr);
|
|
||||||
}
|
|
||||||
|
|
||||||
String labelStr = this.getLabelsArrayInterface();
|
|
||||||
if (labelStr == null || labelStr.isEmpty()) {
|
|
||||||
throw new RuntimeException("Label array interface must not be empty");
|
|
||||||
} else {
|
|
||||||
builder.append(",\"label_str\":" + labelStr);
|
|
||||||
}
|
|
||||||
|
|
||||||
String weightStr = getWeightsArrayInterface();
|
|
||||||
if (weightStr != null && ! weightStr.isEmpty()) {
|
|
||||||
builder.append(",\"weight_str\":" + weightStr);
|
|
||||||
}
|
|
||||||
|
|
||||||
String baseMarginStr = getBaseMarginsArrayInterface();
|
|
||||||
if (baseMarginStr != null && ! baseMarginStr.isEmpty()) {
|
|
||||||
builder.append(",\"basemargin_str\":" + baseMarginStr);
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.append("}");
|
|
||||||
return builder.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the cuda array interface of the feature columns.
|
|
||||||
* The returned value must not be null or empty
|
|
||||||
*/
|
|
||||||
public abstract String getFeatureArrayInterface();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the cuda array interface of the label columns.
|
|
||||||
* The returned value must not be null or empty if we're creating
|
|
||||||
* {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)}
|
|
||||||
*/
|
|
||||||
public abstract String getLabelsArrayInterface();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the cuda array interface of the weight columns.
|
|
||||||
* The returned value can be null or empty
|
|
||||||
*/
|
|
||||||
public abstract String getWeightsArrayInterface();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the cuda array interface of the base margin columns.
|
|
||||||
* The returned value can be null or empty
|
|
||||||
*/
|
|
||||||
public abstract String getBaseMarginsArrayInterface();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() throws Exception {}
|
|
||||||
|
|
||||||
|
/** Get features cuda array interface json string */
|
||||||
|
public abstract String toFeaturesJson();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -195,7 +195,7 @@ public class DMatrix {
|
|||||||
*/
|
*/
|
||||||
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
||||||
long[] out = new long[1];
|
long[] out = new long[1];
|
||||||
String json = columnBatch.getFeatureArrayInterface();
|
String json = columnBatch.toFeaturesJson();
|
||||||
if (json == null || json.isEmpty()) {
|
if (json == null || json.isEmpty()) {
|
||||||
throw new XGBoostError("Expecting non-empty feature columns' array interface");
|
throw new XGBoostError("Expecting non-empty feature columns' array interface");
|
||||||
}
|
}
|
||||||
@ -228,7 +228,7 @@ public class DMatrix {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void setQueryId(Column column) throws XGBoostError {
|
public void setQueryId(Column column) throws XGBoostError {
|
||||||
setXGBDMatrixInfo("qid", column.getArrayInterfaceJson());
|
setXGBDMatrixInfo("qid", column.toJson());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||||
@ -362,7 +362,7 @@ public class DMatrix {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void setLabel(Column column) throws XGBoostError {
|
public void setLabel(Column column) throws XGBoostError {
|
||||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
setXGBDMatrixInfo("label", column.toJson());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -393,7 +393,7 @@ public class DMatrix {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void setWeight(Column column) throws XGBoostError {
|
public void setWeight(Column column) throws XGBoostError {
|
||||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
setXGBDMatrixInfo("weight", column.toJson());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -421,7 +421,7 @@ public class DMatrix {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void setBaseMargin(Column column) throws XGBoostError {
|
public void setBaseMargin(Column column) throws XGBoostError {
|
||||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
setXGBDMatrixInfo("base_margin", column.toJson());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -104,7 +104,8 @@ void CopyInterface(std::vector<xgboost::ArrayInterface<1>> &interface_arr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
|
template <typename T>
|
||||||
|
void CopyMetaInfo(Json *p_interface, dh::device_vector<T> *out, cudaStream_t stream) {
|
||||||
auto &j_interface = *p_interface;
|
auto &j_interface = *p_interface;
|
||||||
CHECK_EQ(get<Array const>(j_interface).size(), 1);
|
CHECK_EQ(get<Array const>(j_interface).size(), 1);
|
||||||
auto object = get<Object>(get<Array>(j_interface)[0]);
|
auto object = get<Object>(get<Array>(j_interface)[0]);
|
||||||
@ -151,9 +152,11 @@ class DataIteratorProxy {
|
|||||||
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
|
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
|
||||||
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
|
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
|
||||||
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
|
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
|
||||||
|
std::vector<std::unique_ptr<dh::device_vector<int>>> qids_;
|
||||||
std::vector<Json> label_interfaces_;
|
std::vector<Json> label_interfaces_;
|
||||||
std::vector<Json> weight_interfaces_;
|
std::vector<Json> weight_interfaces_;
|
||||||
std::vector<Json> margin_interfaces_;
|
std::vector<Json> margin_interfaces_;
|
||||||
|
std::vector<Json> qid_interfaces_;
|
||||||
|
|
||||||
size_t it_{0};
|
size_t it_{0};
|
||||||
size_t n_batches_{0};
|
size_t n_batches_{0};
|
||||||
@ -186,11 +189,11 @@ class DataIteratorProxy {
|
|||||||
void StageMetaInfo(Json json_interface) {
|
void StageMetaInfo(Json json_interface) {
|
||||||
CHECK(!IsA<Null>(json_interface));
|
CHECK(!IsA<Null>(json_interface));
|
||||||
auto json_map = get<Object const>(json_interface);
|
auto json_map = get<Object const>(json_interface);
|
||||||
if (json_map.find("label_str") == json_map.cend()) {
|
if (json_map.find("label") == json_map.cend()) {
|
||||||
LOG(FATAL) << "Must have a label field.";
|
LOG(FATAL) << "Must have a label field.";
|
||||||
}
|
}
|
||||||
|
|
||||||
Json label = json_interface["label_str"];
|
Json label = json_interface["label"];
|
||||||
CHECK(!IsA<Null>(label));
|
CHECK(!IsA<Null>(label));
|
||||||
labels_.emplace_back(new dh::device_vector<float>);
|
labels_.emplace_back(new dh::device_vector<float>);
|
||||||
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
|
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
|
||||||
@ -200,8 +203,8 @@ class DataIteratorProxy {
|
|||||||
Json::Dump(label, &str);
|
Json::Dump(label, &str);
|
||||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||||
|
|
||||||
if (json_map.find("weight_str") != json_map.cend()) {
|
if (json_map.find("weight") != json_map.cend()) {
|
||||||
Json weight = json_interface["weight_str"];
|
Json weight = json_interface["weight"];
|
||||||
CHECK(!IsA<Null>(weight));
|
CHECK(!IsA<Null>(weight));
|
||||||
weights_.emplace_back(new dh::device_vector<float>);
|
weights_.emplace_back(new dh::device_vector<float>);
|
||||||
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
|
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
|
||||||
@ -211,8 +214,8 @@ class DataIteratorProxy {
|
|||||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (json_map.find("basemargin_str") != json_map.cend()) {
|
if (json_map.find("baseMargin") != json_map.cend()) {
|
||||||
Json basemargin = json_interface["basemargin_str"];
|
Json basemargin = json_interface["baseMargin"];
|
||||||
base_margins_.emplace_back(new dh::device_vector<float>);
|
base_margins_.emplace_back(new dh::device_vector<float>);
|
||||||
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
|
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
|
||||||
margin_interfaces_.emplace_back(basemargin);
|
margin_interfaces_.emplace_back(basemargin);
|
||||||
@ -220,6 +223,16 @@ class DataIteratorProxy {
|
|||||||
Json::Dump(basemargin, &str);
|
Json::Dump(basemargin, &str);
|
||||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (json_map.find("qid") != json_map.cend()) {
|
||||||
|
Json qid = json_interface["qid"];
|
||||||
|
qids_.emplace_back(new dh::device_vector<int>);
|
||||||
|
CopyMetaInfo(&qid, qids_.back().get(), copy_stream_);
|
||||||
|
qid_interfaces_.emplace_back(qid);
|
||||||
|
|
||||||
|
Json::Dump(qid, &str);
|
||||||
|
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CloseJvmBatch() {
|
void CloseJvmBatch() {
|
||||||
@ -249,11 +262,11 @@ class DataIteratorProxy {
|
|||||||
// batch should be ColumnBatch from jvm
|
// batch should be ColumnBatch from jvm
|
||||||
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
|
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
|
||||||
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
|
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
|
||||||
jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID(
|
jmethodID toJson = CheckJvmCall(jenv_->GetMethodID(
|
||||||
batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_);
|
batch_class, "toJson", "()Ljava/lang/String;"), jenv_);
|
||||||
|
|
||||||
auto jinterface =
|
auto jinterface =
|
||||||
static_cast<jstring>(jenv_->CallObjectMethod(batch, getArrayInterfaceJson));
|
static_cast<jstring>(jenv_->CallObjectMethod(batch, toJson));
|
||||||
CheckJvmCall(jinterface, jenv_);
|
CheckJvmCall(jinterface, jenv_);
|
||||||
char const *c_interface_str =
|
char const *c_interface_str =
|
||||||
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
|
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
|
||||||
@ -281,7 +294,7 @@ class DataIteratorProxy {
|
|||||||
CHECK(!IsA<Null>(json_interface));
|
CHECK(!IsA<Null>(json_interface));
|
||||||
StageMetaInfo(json_interface);
|
StageMetaInfo(json_interface);
|
||||||
|
|
||||||
Json features = json_interface["features_str"];
|
Json features = json_interface["features"];
|
||||||
auto json_columns = get<Array const>(features);
|
auto json_columns = get<Array const>(features);
|
||||||
std::vector<ArrayInterface<1>> interfaces;
|
std::vector<ArrayInterface<1>> interfaces;
|
||||||
|
|
||||||
@ -337,6 +350,12 @@ class DataIteratorProxy {
|
|||||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (n_batches_ == this->qid_interfaces_.size()) {
|
||||||
|
auto const &qid = this->qid_interfaces_.at(it_);
|
||||||
|
Json::Dump(qid, &str);
|
||||||
|
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
// Data
|
// Data
|
||||||
auto const &json_interface = host_columns_.at(it_)->interfaces;
|
auto const &json_interface = host_columns_.at(it_)->interfaces;
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user