temp merge, disable 1 line, SetValid
This commit is contained in:
@@ -3,161 +3,15 @@
|
||||
[](https://xgboost.readthedocs.org/en/latest/jvm/index.html)
|
||||
[](../LICENSE)
|
||||
|
||||
[Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) |
|
||||
[Documentation](https://xgboost.readthedocs.org/en/stable/jvm/index.html) |
|
||||
[Resources](../demo/README.md) |
|
||||
[Release Notes](../NEWS.md)
|
||||
|
||||
XGBoost4J is the JVM package of xgboost. It brings all the optimizations
|
||||
and power xgboost into JVM ecosystem.
|
||||
XGBoost4J is the JVM package of xgboost. It brings all the optimizations and power xgboost
|
||||
into JVM ecosystem.
|
||||
|
||||
- Train XGBoost models in scala and java with easy customizations.
|
||||
- Run distributed xgboost natively on jvm frameworks such as
|
||||
Apache Flink and Apache Spark.
|
||||
- Train XGBoost models in scala and java with easy customization.
|
||||
- Run distributed xgboost natively on jvm frameworks such as Apache Flink and Apache
|
||||
Spark.
|
||||
|
||||
You can find more about XGBoost on [Documentation](https://xgboost.readthedocs.org/en/latest/jvm/index.html) and [Resource Page](../demo/README.md).
|
||||
|
||||
## Add Maven Dependency
|
||||
|
||||
XGBoost4J, XGBoost4J-Spark, etc. in maven repository is compiled with g++-4.8.5.
|
||||
|
||||
### Access release version
|
||||
|
||||
<b>Maven</b>
|
||||
|
||||
```
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_2.12</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_2.12</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
```
|
||||
or
|
||||
```
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_2.13</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_2.13</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
<b>sbt</b>
|
||||
```sbt
|
||||
libraryDependencies ++= Seq(
|
||||
"ml.dmlc" %% "xgboost4j" % "latest_version_num",
|
||||
"ml.dmlc" %% "xgboost4j-spark" % "latest_version_num"
|
||||
)
|
||||
```
|
||||
|
||||
For the latest release version number, please check [here](https://github.com/dmlc/xgboost/releases).
|
||||
|
||||
|
||||
### Access SNAPSHOT version
|
||||
|
||||
First add the following Maven repository hosted by the XGBoost project:
|
||||
|
||||
<b>Maven</b>:
|
||||
|
||||
```xml
|
||||
<repository>
|
||||
<id>XGBoost4J Snapshot Repo</id>
|
||||
<name>XGBoost4J Snapshot Repo</name>
|
||||
<url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/</url>
|
||||
</repository>
|
||||
```
|
||||
|
||||
<b>sbt</b>:
|
||||
|
||||
```sbt
|
||||
resolvers += "XGBoost4J Snapshot Repo" at "https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/"
|
||||
```
|
||||
|
||||
Then add XGBoost4J as a dependency:
|
||||
|
||||
<b>Maven</b>
|
||||
|
||||
```
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_2.12</artifactId>
|
||||
<version>latest_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_2.12</artifactId>
|
||||
<version>latest_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
```
|
||||
or with scala 2.13
|
||||
```
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_2.13</artifactId>
|
||||
<version>latest_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_2.13</artifactId>
|
||||
<version>latest_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
<b>sbt</b>
|
||||
```sbt
|
||||
libraryDependencies ++= Seq(
|
||||
"ml.dmlc" %% "xgboost4j" % "latest_version_num-SNAPSHOT",
|
||||
"ml.dmlc" %% "xgboost4j-spark" % "latest_version_num-SNAPSHOT"
|
||||
)
|
||||
```
|
||||
|
||||
For the latest release version number, please check [the repository listing](https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/list.html).
|
||||
|
||||
### GPU algorithm
|
||||
To enable the GPU algorithm (`tree_method='gpu_hist'`), use artifacts `xgboost4j-gpu_2.12` and `xgboost4j-spark-gpu_2.12` instead.
|
||||
Note that scala 2.13 is not supported by the [NVIDIA/spark-rapids#1525](https://github.com/NVIDIA/spark-rapids/issues/1525) yet, so the GPU algorithm can only be used with scala 2.12.
|
||||
|
||||
## Examples
|
||||
|
||||
Full code examples for Scala, Java, Apache Spark, and Apache Flink can
|
||||
be found in the [examples package](https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example).
|
||||
|
||||
**NOTE on LIBSVM Format**:
|
||||
|
||||
There is an inconsistent issue between XGBoost4J-Spark and other language bindings of XGBoost.
|
||||
|
||||
When users use Spark to load trainingset/testset in LIBSVM format with the following code snippet:
|
||||
|
||||
```scala
|
||||
spark.read.format("libsvm").load("trainingset_libsvm")
|
||||
```
|
||||
|
||||
Spark assumes that the dataset is 1-based indexed. However, when you do prediction with other bindings of XGBoost (e.g. Python API of XGBoost), XGBoost assumes that the dataset is 0-based indexed. It creates a pitfall for the users who train model with Spark but predict with the dataset in the same format in other bindings of XGBoost.
|
||||
|
||||
## Development
|
||||
|
||||
You can build/package xgboost4j locally with the following steps:
|
||||
|
||||
**Linux:**
|
||||
1. Ensure [Docker for Linux](https://docs.docker.com/install/) is installed.
|
||||
2. Clone this repo: `git clone --recursive https://github.com/dmlc/xgboost.git`
|
||||
3. Run the following command:
|
||||
- With Tests: `./xgboost/jvm-packages/dev/build-linux.sh`
|
||||
- Skip Tests: `./xgboost/jvm-packages/dev/build-linux.sh --skip-tests`
|
||||
|
||||
**Windows:**
|
||||
1. Ensure [Docker for Windows](https://docs.docker.com/docker-for-windows/install/) is installed.
|
||||
2. Clone this repo: `git clone --recursive https://github.com/dmlc/xgboost.git`
|
||||
3. Run the following command:
|
||||
- With Tests: `.\xgboost\jvm-packages\dev\build-linux.cmd`
|
||||
- Skip Tests: `.\xgboost\jvm-packages\dev\build-linux.cmd --skip-tests`
|
||||
|
||||
*Note: this will create jars for deployment on Linux machines.*
|
||||
You can find more about XGBoost on [Documentation](https://xgboost.readthedocs.org/en/stable/jvm/index.html) and [Resource Page](../demo/README.md).
|
||||
3
jvm-packages/dev/.gitattributes
vendored
3
jvm-packages/dev/.gitattributes
vendored
@@ -1,3 +0,0 @@
|
||||
# Set line endings to LF, even on Windows. Otherwise, execution within Docker fails.
|
||||
# See https://help.github.com/articles/dealing-with-line-endings/
|
||||
*.sh text eol=lf
|
||||
1
jvm-packages/dev/.gitignore
vendored
1
jvm-packages/dev/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
.m2
|
||||
@@ -1,58 +0,0 @@
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
FROM centos:7
|
||||
|
||||
# Install all basic requirements
|
||||
RUN \
|
||||
yum -y update && \
|
||||
yum install -y bzip2 make tar unzip wget xz git centos-release-scl yum-utils java-1.8.0-openjdk-devel && \
|
||||
yum-config-manager --enable centos-sclo-rh-testing && \
|
||||
yum -y update && \
|
||||
yum install -y devtoolset-7-gcc devtoolset-7-binutils devtoolset-7-gcc-c++ && \
|
||||
# Python
|
||||
wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
|
||||
bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
|
||||
# CMake
|
||||
wget -nv -nc https://cmake.org/files/v3.18/cmake-3.18.3-Linux-x86_64.sh --no-check-certificate && \
|
||||
bash cmake-3.18.3-Linux-x86_64.sh --skip-license --prefix=/usr && \
|
||||
# Maven
|
||||
wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \
|
||||
tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \
|
||||
ln -s /opt/apache-maven-3.6.1/ /opt/maven
|
||||
|
||||
# Set the required environment variables
|
||||
ENV PATH=/opt/python/bin:/opt/maven/bin:$PATH
|
||||
ENV CC=/opt/rh/devtoolset-7/root/usr/bin/gcc
|
||||
ENV CXX=/opt/rh/devtoolset-7/root/usr/bin/c++
|
||||
ENV CPP=/opt/rh/devtoolset-7/root/usr/bin/cpp
|
||||
ENV JAVA_HOME=/usr/lib/jvm/java
|
||||
|
||||
# Install Python packages
|
||||
RUN \
|
||||
pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22 awscli
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
|
||||
# Install lightweight sudo (not bound to TTY)
|
||||
RUN set -ex; \
|
||||
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
|
||||
chmod +x /usr/local/bin/gosu && \
|
||||
gosu nobody true
|
||||
|
||||
WORKDIR /xgboost
|
||||
@@ -1,44 +0,0 @@
|
||||
@echo off
|
||||
|
||||
rem
|
||||
rem Licensed to the Apache Software Foundation (ASF) under one
|
||||
rem or more contributor license agreements. See the NOTICE file
|
||||
rem distributed with this work for additional information
|
||||
rem regarding copyright ownership. The ASF licenses this file
|
||||
rem to you under the Apache License, Version 2.0 (the
|
||||
rem "License"); you may not use this file except in compliance
|
||||
rem with the License. You may obtain a copy of the License at
|
||||
rem
|
||||
rem http://www.apache.org/licenses/LICENSE-2.0
|
||||
rem
|
||||
rem Unless required by applicable law or agreed to in writing,
|
||||
rem software distributed under the License is distributed on an
|
||||
rem "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
rem KIND, either express or implied. See the License for the
|
||||
rem specific language governing permissions and limitations
|
||||
rem under the License.
|
||||
rem
|
||||
|
||||
rem The the local path of this file
|
||||
set "BASEDIR=%~dp0"
|
||||
|
||||
rem The local path of .m2 directory for maven
|
||||
set "M2DIR=%BASEDIR%\.m2\"
|
||||
|
||||
rem Create a local .m2 directory if needed
|
||||
if not exist "%M2DIR%" mkdir "%M2DIR%"
|
||||
|
||||
rem Build and tag the Dockerfile
|
||||
docker build -t dmlc/xgboost4j-build %BASEDIR%
|
||||
|
||||
docker run^
|
||||
-it^
|
||||
--rm^
|
||||
--memory 12g^
|
||||
--env JAVA_OPTS="-Xmx9g"^
|
||||
--env MAVEN_OPTS="-Xmx3g"^
|
||||
--ulimit core=-1^
|
||||
--volume %BASEDIR%\..\..:/xgboost^
|
||||
--volume %M2DIR%:/root/.m2^
|
||||
dmlc/xgboost4j-build^
|
||||
/xgboost/jvm-packages/dev/package-linux.sh "%*"
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
BASEDIR="$( cd "$( dirname "$0" )" && pwd )" # the directory of this file
|
||||
|
||||
docker build -t dmlc/xgboost4j-build "${BASEDIR}" # build and tag the Dockerfile
|
||||
|
||||
exec docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--memory 12g \
|
||||
--env JAVA_OPTS="-Xmx9g" \
|
||||
--env MAVEN_OPTS="-Xmx3g -Dmaven.repo.local=/xgboost/jvm-packages/dev/.m2" \
|
||||
--env CI_BUILD_UID=`id -u` \
|
||||
--env CI_BUILD_GID=`id -g` \
|
||||
--env CI_BUILD_USER=`id -un` \
|
||||
--env CI_BUILD_GROUP=`id -gn` \
|
||||
--ulimit core=-1 \
|
||||
--volume "${BASEDIR}/../..":/xgboost \
|
||||
dmlc/xgboost4j-build \
|
||||
/xgboost/tests/ci_build/entrypoint.sh jvm-packages/dev/package-linux.sh "$@"
|
||||
|
||||
# CI_BUILD_UID, CI_BUILD_GID, CI_BUILD_USER, CI_BUILD_GROUP
|
||||
# are used by entrypoint.sh to create the user with the same uid in a container
|
||||
# so all produced artifacts would be owned by your host user
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
cd jvm-packages
|
||||
|
||||
case "$1" in
|
||||
--skip-tests) SKIP_TESTS=true ;;
|
||||
"") SKIP_TESTS=false ;;
|
||||
esac
|
||||
|
||||
if [[ -n ${SKIP_TESTS} ]]; then
|
||||
if [[ ${SKIP_TESTS} == "true" ]]; then
|
||||
mvn --batch-mode clean package -DskipTests
|
||||
elif [[ ${SKIP_TESTS} == "false" ]]; then
|
||||
mvn --batch-mode clean package
|
||||
fi
|
||||
else
|
||||
echo "Usage: $0 [--skip-tests]"
|
||||
exit 1
|
||||
fi
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>XGBoost JVM Package</name>
|
||||
<description>JVM Package for XGBoost</description>
|
||||
@@ -35,19 +35,19 @@
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<flink.version>1.17.1</flink.version>
|
||||
<junit.version>4.13.2</junit.version>
|
||||
<spark.version>3.4.0</spark.version>
|
||||
<spark.version.gpu>3.3.2</spark.version.gpu>
|
||||
<spark.version>3.4.1</spark.version>
|
||||
<spark.version.gpu>3.4.1</spark.version.gpu>
|
||||
<scala.version>2.12.18</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<hadoop.version>3.3.5</hadoop.version>
|
||||
<hadoop.version>3.3.6</hadoop.version>
|
||||
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
|
||||
<log.capi.invocation>OFF</log.capi.invocation>
|
||||
<use.cuda>OFF</use.cuda>
|
||||
<cudf.version>23.04.0</cudf.version>
|
||||
<spark.rapids.version>23.04.1</spark.rapids.version>
|
||||
<cudf.version>23.08.0</cudf.version>
|
||||
<spark.rapids.version>23.08.1</spark.rapids.version>
|
||||
<cudf.classifier>cuda11</cudf.classifier>
|
||||
<scalatest.version>3.2.16</scalatest.version>
|
||||
<scala-collection-compat.version>2.10.0</scala-collection-compat.version>
|
||||
<scala-collection-compat.version>2.11.0</scala-collection-compat.version>
|
||||
</properties>
|
||||
<repositories>
|
||||
<repository>
|
||||
@@ -78,7 +78,7 @@
|
||||
<id>scala-2.13</id>
|
||||
<properties>
|
||||
<scala.binary.version>2.13</scala.binary.version>
|
||||
<scala.version>2.13.10</scala.version>
|
||||
<scala.version>2.13.11</scala.version>
|
||||
</properties>
|
||||
</profile>
|
||||
|
||||
@@ -91,6 +91,9 @@
|
||||
<value>ON</value>
|
||||
</property>
|
||||
</activation>
|
||||
<properties>
|
||||
<use.cuda>ON</use.cuda>
|
||||
</properties>
|
||||
<modules>
|
||||
<module>xgboost4j-gpu</module>
|
||||
<module>xgboost4j-spark-gpu</module>
|
||||
@@ -470,28 +473,11 @@
|
||||
</plugins>
|
||||
</reporting>
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.esotericsoftware</groupId>
|
||||
<artifactId>kryo</artifactId>
|
||||
<version>5.5.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-compiler</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang.modules</groupId>
|
||||
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
|
||||
<version>${scala-collection-compat.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
|
||||
@@ -6,11 +6,11 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<name>xgboost4j-example</name>
|
||||
<artifactId>xgboost4j-example_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
@@ -40,20 +40,20 @@ object SparkMLlibPipeline {
|
||||
val nativeModelPath = args(1)
|
||||
val pipelineModelPath = args(2)
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
val (device, numWorkers) = if (args.length == 4 && args(3) == "gpu") {
|
||||
("cuda", 1)
|
||||
} else ("cpu", 2)
|
||||
|
||||
val spark = SparkSession
|
||||
.builder()
|
||||
.appName("XGBoost4J-Spark Pipeline Example")
|
||||
.getOrCreate()
|
||||
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, treeMethod, numWorkers)
|
||||
run(spark, inputPath, nativeModelPath, pipelineModelPath, device, numWorkers)
|
||||
.show(false)
|
||||
}
|
||||
private[spark] def run(spark: SparkSession, inputPath: String, nativeModelPath: String,
|
||||
pipelineModelPath: String, treeMethod: String,
|
||||
pipelineModelPath: String, device: String,
|
||||
numWorkers: Int): DataFrame = {
|
||||
|
||||
// Load dataset
|
||||
@@ -82,13 +82,14 @@ object SparkMLlibPipeline {
|
||||
.setOutputCol("classIndex")
|
||||
.fit(training)
|
||||
val booster = new XGBoostClassifier(
|
||||
Map("eta" -> 0.1f,
|
||||
Map(
|
||||
"eta" -> 0.1f,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod
|
||||
"device" -> device
|
||||
)
|
||||
)
|
||||
booster.setFeaturesCol("features")
|
||||
|
||||
@@ -31,18 +31,18 @@ object SparkTraining {
|
||||
sys.exit(1)
|
||||
}
|
||||
|
||||
val (treeMethod, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("gpu_hist", 1)
|
||||
} else ("auto", 2)
|
||||
val (device, numWorkers) = if (args.length == 2 && args(1) == "gpu") {
|
||||
("cuda", 1)
|
||||
} else ("cpu", 2)
|
||||
|
||||
val spark = SparkSession.builder().getOrCreate()
|
||||
val inputPath = args(0)
|
||||
val results: DataFrame = run(spark, inputPath, treeMethod, numWorkers)
|
||||
val results: DataFrame = run(spark, inputPath, device, numWorkers)
|
||||
results.show()
|
||||
}
|
||||
|
||||
private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
treeMethod: String, numWorkers: Int): DataFrame = {
|
||||
device: String, numWorkers: Int): DataFrame = {
|
||||
val schema = new StructType(Array(
|
||||
StructField("sepal length", DoubleType, true),
|
||||
StructField("sepal width", DoubleType, true),
|
||||
@@ -80,7 +80,7 @@ private[spark] def run(spark: SparkSession, inputPath: String,
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 100,
|
||||
"num_workers" -> numWorkers,
|
||||
"tree_method" -> treeMethod,
|
||||
"device" -> device,
|
||||
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
|
||||
val xgbClassifier = new XGBoostClassifier(xgbParam).
|
||||
setFeaturesCol("features").
|
||||
|
||||
@@ -104,7 +104,7 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
|
||||
test("Smoke test for SparkMLlibPipeline example") {
|
||||
SparkMLlibPipeline.run(spark, pathToTestDataset.toString, "target/native-model",
|
||||
"target/pipeline-model", "auto", 2)
|
||||
"target/pipeline-model", "cpu", 2)
|
||||
}
|
||||
|
||||
test("Smoke test for SparkTraining example") {
|
||||
@@ -118,6 +118,6 @@ class SparkExamplesTest extends AnyFunSuite with BeforeAndAfterAll {
|
||||
.config("spark.task.cpus", 1)
|
||||
.getOrCreate()
|
||||
|
||||
SparkTraining.run(spark, pathToTestDataset.toString, "auto", 2)
|
||||
SparkTraining.run(spark, pathToTestDataset.toString, "cpu", 2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,12 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<name>xgboost4j-flink</name>
|
||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
<properties>
|
||||
<flink-ml.version>2.2.0</flink-ml.version>
|
||||
</properties>
|
||||
|
||||
@@ -6,14 +6,29 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-gpu_${scala.binary.version}</artifactId>
|
||||
<name>xgboost4j-gpu</name>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-compiler</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang.modules</groupId>
|
||||
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
|
||||
<version>${scala-collection-compat.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.rapids</groupId>
|
||||
<artifactId>cudf</artifactId>
|
||||
@@ -48,7 +63,7 @@
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.12.0</version>
|
||||
<version>3.13.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
@@ -77,8 +77,8 @@ public class BoosterTest {
|
||||
put("objective", "binary:logistic");
|
||||
put("num_round", round);
|
||||
put("num_workers", 1);
|
||||
put("tree_method", "gpu_hist");
|
||||
put("predictor", "gpu_predictor");
|
||||
put("tree_method", "hist");
|
||||
put("device", "cuda");
|
||||
put("max_bin", maxBin);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<name>xgboost4j-spark-gpu</name>
|
||||
<artifactId>xgboost4j-spark-gpu_${scala.binary.version}</artifactId>
|
||||
|
||||
@@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
|
||||
estimator match {
|
||||
case est: XGBoostEstimatorCommon =>
|
||||
require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||
s"GPU train requires tree_method set to gpu_hist")
|
||||
require(
|
||||
est.isDefined(est.device) &&
|
||||
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
|
||||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
|
||||
s"GPU train requires `device` set to `cuda` or `gpu`."
|
||||
)
|
||||
val groupName = estimator match {
|
||||
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
|
||||
regressor.getGroupCol } else ""
|
||||
@@ -280,8 +284,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
// - gpu id
|
||||
// - predictor: Force to gpu predictor since native doesn't save predictor.
|
||||
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
|
||||
booster.setParam("gpu_id", gpuId.toString)
|
||||
booster.setParam("predictor", "gpu_predictor")
|
||||
booster.setParam("device", s"cuda:$gpuId")
|
||||
logger.info("GPU transform on device: " + gpuId)
|
||||
boosterFlag.isGpuParamsSet = true;
|
||||
}
|
||||
|
||||
@@ -282,7 +282,7 @@ object SparkSessionHolder extends Logging {
|
||||
logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
|
||||
setAllConfs(conf.getAll)
|
||||
logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
|
||||
spark.sparkContext.setLogLevel("WARN")
|
||||
f(spark)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -50,9 +50,12 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val xgbParam = Map(
|
||||
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
|
||||
"tree_method" -> "hist", "device" -> "cuda",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName
|
||||
)
|
||||
new XGBoostClassifier(xgbParam)
|
||||
.fit(trainingDf)
|
||||
}
|
||||
@@ -65,8 +68,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
|
||||
trainingDf = trainingDf.select(labelName, "f2", weightName, "f3", baseMarginName, "f1")
|
||||
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist")
|
||||
val xgbParam = Map(
|
||||
"eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob",
|
||||
"num_class" -> 3, "num_round" -> 5, "num_workers" -> 1,
|
||||
"tree_method" -> "hist", "device" -> "cuda"
|
||||
)
|
||||
new XGBoostClassifier(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
@@ -127,7 +133,7 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("Throw exception when tree method is not set to gpu_hist") {
|
||||
test("Throw exception when device is not set to cuda") {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||
@@ -139,12 +145,11 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
.setLabelCol(labelName)
|
||||
.fit(trainingDf)
|
||||
}
|
||||
assert(thrown.getMessage.contains("GPU train requires tree_method set to gpu_hist"))
|
||||
assert(thrown.getMessage.contains("GPU train requires `device` set to `cuda`"))
|
||||
}
|
||||
}
|
||||
|
||||
test("Train with eval") {
|
||||
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
|
||||
@@ -184,4 +189,24 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("device ordinal should not be specified") {
|
||||
withGpuSparkSession() { spark =>
|
||||
import spark.implicits._
|
||||
val trainingDf = trainingData.toDF(allColumnNames: _*)
|
||||
val params = Map(
|
||||
"objective" -> "multi:softprob",
|
||||
"num_class" -> 3,
|
||||
"num_round" -> 5,
|
||||
"num_workers" -> 1
|
||||
)
|
||||
val thrown = intercept[IllegalArgumentException] {
|
||||
new XGBoostClassifier(params)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
.setDevice("cuda:1")
|
||||
.fit(trainingDf)
|
||||
}
|
||||
assert(thrown.getMessage.contains("`cuda` or `gpu`"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -40,7 +40,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
test("The transform result should be same for several runs on same model") {
|
||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
@@ -54,10 +54,30 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("Tree method gpu_hist still works") {
|
||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||
val params = Map(
|
||||
"tree_method" -> "gpu_hist",
|
||||
"features_cols" -> featureNames,
|
||||
"label_col" -> labelName,
|
||||
"num_round" -> 10,
|
||||
"num_workers" -> 1
|
||||
)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
// Get a model
|
||||
val model = new XGBoostRegressor(params).fit(originalDf)
|
||||
val left = model.transform(testDf).collect()
|
||||
val right = model.transform(testDf).collect()
|
||||
// The left should be same with right
|
||||
assert(compareResults(true, 0.000001, left, right))
|
||||
}
|
||||
}
|
||||
|
||||
test("use weight") {
|
||||
withGpuSparkSession(enableCsvConf()) { spark =>
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "reg:squarederror",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "hist", "device" -> "cuda",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
@@ -88,7 +108,8 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
val classifier = new XGBoostRegressor(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
.setTreeMethod("gpu_hist")
|
||||
.setTreeMethod("hist")
|
||||
.setDevice("cuda")
|
||||
(classifier.fit(rawInput), testDf)
|
||||
}
|
||||
|
||||
@@ -175,7 +196,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
val classifier = new XGBoostRegressor(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
.setLabelCol(labelName)
|
||||
.setTreeMethod("gpu_hist")
|
||||
.setDevice("cuda")
|
||||
classifier.fit(rawInput)
|
||||
}
|
||||
|
||||
@@ -234,5 +255,4 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
|
||||
assert(testDf.count() === ret.length)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<name>xgboost4j-spark</name>
|
||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||
|
||||
@@ -23,7 +23,6 @@ import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@@ -55,9 +54,6 @@ object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
maximizeEvalMetrics: Boolean)
|
||||
|
||||
private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
|
||||
|
||||
private[scala] case class XGBoostExecutionParams(
|
||||
@@ -71,10 +67,12 @@ private[scala] case class XGBoostExecutionParams(
|
||||
trackerConf: TrackerConf,
|
||||
checkpointParam: Option[ExternalCheckpointParams],
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
earlyStoppingRounds: Int,
|
||||
cacheTrainingSet: Boolean,
|
||||
treeMethod: Option[String],
|
||||
isLocal: Boolean) {
|
||||
device: Option[String],
|
||||
isLocal: Boolean,
|
||||
featureNames: Option[Array[String]],
|
||||
featureTypes: Option[Array[String]]) {
|
||||
|
||||
private var rawParamMap: Map[String, Any] = _
|
||||
|
||||
@@ -95,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
|
||||
private val overridedParams = overrideParams(rawParams, sc)
|
||||
|
||||
validateSparkSslConf()
|
||||
|
||||
/**
|
||||
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
|
||||
* If so, throw an exception unless this safety measure has been explicitly overridden
|
||||
* via conf `xgboost.spark.ignoreSsl`.
|
||||
*/
|
||||
private def validateSparkSslConf: Unit = {
|
||||
private def validateSparkSslConf(): Unit = {
|
||||
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
|
||||
SparkSession.getActiveSession match {
|
||||
case Some(ss) =>
|
||||
@@ -144,83 +144,92 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
val numEarlyStoppingRounds = overridedParams.getOrElse(
|
||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||
overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
|
||||
if (numEarlyStoppingRounds > 0 &&
|
||||
!overridedParams.contains("maximize_evaluation_metrics")) {
|
||||
if (overridedParams.getOrElse("custom_eval", null) != null) {
|
||||
if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
|
||||
throw new IllegalArgumentException("custom_eval does not support early stopping")
|
||||
}
|
||||
val eval_metric = overridedParams("eval_metric").toString
|
||||
val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric
|
||||
logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize)
|
||||
overridedParams += ("maximize_evaluation_metrics" -> maximize)
|
||||
}
|
||||
overridedParams
|
||||
}
|
||||
|
||||
/**
|
||||
* The Map parameters accepted by estimator's constructor may have string type,
|
||||
* Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
|
||||
* kind of parameters into the correct type in the function.
|
||||
*
|
||||
* @return XGBoostExecutionParams
|
||||
*/
|
||||
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||
val useExternalMemory = overridedParams
|
||||
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
||||
|
||||
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
|
||||
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
|
||||
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||
val allowNonZeroForMissing = overridedParams
|
||||
.getOrElse("allow_non_zero_for_missing", false)
|
||||
.asInstanceOf[Boolean]
|
||||
validateSparkSslConf
|
||||
var treeMethod: Option[String] = None
|
||||
if (overridedParams.contains("tree_method")) {
|
||||
require(overridedParams("tree_method") == "hist" ||
|
||||
overridedParams("tree_method") == "approx" ||
|
||||
overridedParams("tree_method") == "auto" ||
|
||||
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
|
||||
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
|
||||
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
|
||||
}
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
"'eval_set_names'")
|
||||
}
|
||||
require(nWorkers > 0, "you must specify more than 0 workers")
|
||||
if (obj != null) {
|
||||
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
|
||||
"is not defined, you have to specify the objective type as classification or regression" +
|
||||
" with a customized objective function")
|
||||
}
|
||||
|
||||
var trainTestRatio = 1.0
|
||||
if (overridedParams.contains("train_test_ratio")) {
|
||||
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
|
||||
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
|
||||
"'eval_set_names'")
|
||||
trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
|
||||
}
|
||||
|
||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||
val useExternalMemory = overridedParams
|
||||
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
|
||||
|
||||
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
|
||||
val allowNonZeroForMissing = overridedParams
|
||||
.getOrElse("allow_non_zero_for_missing", false)
|
||||
.asInstanceOf[Boolean]
|
||||
|
||||
val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
|
||||
// back-compatible with "gpu_hist"
|
||||
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
|
||||
Some("cuda")
|
||||
} else overridedParams.get("device").map(_.toString)
|
||||
|
||||
require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
|
||||
"The tree method \"approx\" is not yet supported for Spark GPU cluster")
|
||||
|
||||
val trackerConf = overridedParams.get("tracker_conf") match {
|
||||
case None => TrackerConf()
|
||||
case Some(conf: TrackerConf) => conf
|
||||
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
|
||||
"instance of TrackerConf.")
|
||||
}
|
||||
val checkpointParam =
|
||||
ExternalCheckpointParams.extractParams(overridedParams)
|
||||
|
||||
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
|
||||
.asInstanceOf[Double]
|
||||
val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
|
||||
|
||||
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
|
||||
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
|
||||
|
||||
val earlyStoppingRounds = overridedParams.getOrElse(
|
||||
"num_early_stopping_rounds", 0).asInstanceOf[Int]
|
||||
val maximizeEvalMetrics = overridedParams.getOrElse(
|
||||
"maximize_evaluation_metrics", true).asInstanceOf[Boolean]
|
||||
val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds,
|
||||
maximizeEvalMetrics)
|
||||
|
||||
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
||||
.asInstanceOf[Boolean]
|
||||
|
||||
val featureNames = if (overridedParams.contains("feature_names")) {
|
||||
Some(overridedParams("feature_names").asInstanceOf[Array[String]])
|
||||
} else None
|
||||
val featureTypes = if (overridedParams.contains("feature_types")){
|
||||
Some(overridedParams("feature_types").asInstanceOf[Array[String]])
|
||||
} else None
|
||||
|
||||
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
|
||||
missing, allowNonZeroForMissing, trackerConf,
|
||||
checkpointParam,
|
||||
inputParams,
|
||||
xgbExecEarlyStoppingParams,
|
||||
earlyStoppingRounds,
|
||||
cacheTrainingSet,
|
||||
treeMethod,
|
||||
isLocal)
|
||||
device,
|
||||
isLocal,
|
||||
featureNames,
|
||||
featureTypes
|
||||
)
|
||||
xgbExecParam.setRawParamMap(overridedParams)
|
||||
xgbExecParam
|
||||
}
|
||||
@@ -301,12 +310,12 @@ object XGBoost extends Serializable {
|
||||
|
||||
watches = buildWatchesAndCheck(buildWatches)
|
||||
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
|
||||
var params = xgbExecutionParam.toMap
|
||||
if (xgbExecutionParam.treeMethod.exists(m => m == "gpu_hist")) {
|
||||
if (xgbExecutionParam.device.exists(m => (m == "cuda" || m == "gpu"))) {
|
||||
val gpuId = if (xgbExecutionParam.isLocal) {
|
||||
// For local mode, force gpu id to primary device
|
||||
0
|
||||
@@ -314,8 +323,9 @@ object XGBoost extends Serializable {
|
||||
getGPUAddrFromResources
|
||||
}
|
||||
logger.info("Leveraging gpu device " + gpuId + " to train")
|
||||
params = params + ("gpu_id" -> gpuId)
|
||||
params = params + ("device" -> s"cuda:$gpuId")
|
||||
}
|
||||
|
||||
val booster = if (makeCheckpoint) {
|
||||
SXGBoost.trainAndSaveCheckpoint(
|
||||
watches.toMap("train"), params, numRounds,
|
||||
@@ -403,7 +413,10 @@ object XGBoost extends Serializable {
|
||||
|
||||
}}
|
||||
|
||||
val (booster, metrics) = boostersAndMetrics.collect()(0)
|
||||
// The repartition step is to make training stage as ShuffleMapStage, so that when one
|
||||
// of the training task fails the training stage can retry. ResultStage won't retry when
|
||||
// it fails.
|
||||
val (booster, metrics) = boostersAndMetrics.repartition(1).collect()(0)
|
||||
val trackerReturnVal = tracker.waitFor(0L)
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
if (trackerReturnVal != 0) {
|
||||
@@ -531,6 +544,16 @@ private object Watches {
|
||||
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
||||
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
||||
|
||||
if (xgbExecutionParams.featureNames.isDefined) {
|
||||
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||
}
|
||||
|
||||
if (xgbExecutionParams.featureTypes.isDefined) {
|
||||
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||
}
|
||||
|
||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||
}
|
||||
|
||||
@@ -643,6 +666,15 @@ private object Watches {
|
||||
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
||||
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
||||
|
||||
if (xgbExecutionParams.featureNames.isDefined) {
|
||||
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||
}
|
||||
if (xgbExecutionParams.featureTypes.isDefined) {
|
||||
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||
}
|
||||
|
||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +93,8 @@ class XGBoostClassifier (
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setDevice(value: String): this.type = set(device, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
@@ -139,6 +141,12 @@ class XGBoostClassifier (
|
||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
||||
set(singlePrecisionHistogram, value)
|
||||
|
||||
def setFeatureNames(value: Array[String]): this.type =
|
||||
set(featureNames, value)
|
||||
|
||||
def setFeatureTypes(value: Array[String]): this.type =
|
||||
set(featureTypes, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
|
||||
@@ -95,6 +95,8 @@ class XGBoostRegressor (
|
||||
|
||||
def setTreeMethod(value: String): this.type = set(treeMethod, value)
|
||||
|
||||
def setDevice(value: String): this.type = set(device, value)
|
||||
|
||||
def setGrowPolicy(value: String): this.type = set(growPolicy, value)
|
||||
|
||||
def setMaxBins(value: Int): this.type = set(maxBins, value)
|
||||
@@ -141,6 +143,12 @@ class XGBoostRegressor (
|
||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
||||
set(singlePrecisionHistogram, value)
|
||||
|
||||
def setFeatureNames(value: Array[String]): this.type =
|
||||
set(featureNames, value)
|
||||
|
||||
def setFeatureTypes(value: Array[String]): this.type =
|
||||
set(featureTypes, value)
|
||||
|
||||
// called at the start of fit/train when 'eval_metric' is not defined
|
||||
private def setupDefaultEvalMetric(): String = {
|
||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||
|
||||
@@ -154,6 +154,14 @@ private[spark] trait BoosterParams extends Params {
|
||||
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
|
||||
|
||||
final def getTreeMethod: String = $(treeMethod)
|
||||
/**
|
||||
* The device for running XGBoost algorithms, options: cpu, cuda
|
||||
*/
|
||||
final val device = new Param[String](
|
||||
this, "device", "The device for running XGBoost algorithms, options: cpu, cuda"
|
||||
)
|
||||
|
||||
final def getDevice: String = $(device)
|
||||
|
||||
/**
|
||||
* growth policy for fast histogram algorithm
|
||||
|
||||
@@ -177,6 +177,21 @@ private[spark] trait GeneralParams extends Params {
|
||||
|
||||
final def getSeed: Long = $(seed)
|
||||
|
||||
/** Feature's name, it will be set to DMatrix and Booster, and in the final native json model.
|
||||
* In native code, the parameter name is feature_name.
|
||||
* */
|
||||
final val featureNames = new StringArrayParam(this, "feature_names",
|
||||
"an array of feature names")
|
||||
|
||||
final def getFeatureNames: Array[String] = $(featureNames)
|
||||
|
||||
/** Feature types, q is numeric and c is categorical.
|
||||
* In native code, the parameter name is feature_type
|
||||
* */
|
||||
final val featureTypes = new StringArrayParam(this, "feature_types",
|
||||
"an array of feature types")
|
||||
|
||||
final def getFeatureTypes: Array[String] = $(featureTypes)
|
||||
}
|
||||
|
||||
trait HasLeafPredictionCol extends Params {
|
||||
@@ -269,7 +284,7 @@ private[spark] trait ParamMapFuncs extends Params {
|
||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||
paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
|
||||
throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker,prune or" +
|
||||
s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
|
||||
s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
|
||||
}
|
||||
val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
|
||||
|
||||
@@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params {
|
||||
/**
|
||||
* Fraction of training points to use for testing.
|
||||
*/
|
||||
@Deprecated
|
||||
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
|
||||
"fraction of training points to use for testing",
|
||||
ParamValidators.inRange(0, 1))
|
||||
setDefault(trainTestRatio, 1.0)
|
||||
|
||||
@Deprecated
|
||||
final def getTrainTestRatio: Double = $(trainTestRatio)
|
||||
|
||||
/**
|
||||
@@ -112,8 +114,4 @@ private[spark] object LearningTaskParams {
|
||||
|
||||
val supportedObjectiveType = HashSet("regression", "classification")
|
||||
|
||||
val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")
|
||||
|
||||
val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "mape", "logloss", "error", "merror",
|
||||
"mlogloss", "gamma-deviance")
|
||||
}
|
||||
|
||||
@@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
classifier.getBaseScore
|
||||
}
|
||||
}
|
||||
|
||||
test("approx can't be used for gpu train") {
|
||||
val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val thrown = intercept[IllegalArgumentException] {
|
||||
xgb.fit(trainingDF)
|
||||
}
|
||||
assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
|
||||
"for Spark GPU cluster"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ import org.apache.commons.io.IOUtils
|
||||
|
||||
import org.apache.spark.Partitioner
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.json4s.{DefaultFormats, Formats}
|
||||
import org.json4s.jackson.parseJson
|
||||
|
||||
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||
|
||||
@@ -453,4 +455,25 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
}
|
||||
|
||||
test("native json model file should store feature_name and feature_type") {
|
||||
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
|
||||
val featureTypes = (1 to 33).map(idx => "q").toArray
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
|
||||
"num_workers" -> numWorkers, "tree_method" -> treeMethod
|
||||
)
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
.setFeatureNames(featureNames)
|
||||
.setFeatureTypes(featureTypes)
|
||||
val model = xgb.fit(trainingDF)
|
||||
val modelStr = new String(model._booster.toByteArray("json"))
|
||||
val jsonModel = parseJson(modelStr)
|
||||
implicit val formats: Formats = DefaultFormats
|
||||
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
|
||||
assert(featureNamesInModel.length == 33)
|
||||
assert(featureTypesInModel.length == 33)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,14 +6,29 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<name>xgboost4j</name>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>2.1.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-compiler</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang.modules</groupId>
|
||||
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
|
||||
<version>${scala-collection-compat.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-hdfs</artifactId>
|
||||
|
||||
@@ -162,6 +162,51 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature names from the Booster.
|
||||
* @return
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public final String[] getFeatureNames() throws XGBoostError {
|
||||
int numFeature = (int) getNumFeature();
|
||||
String[] out = new String[numFeature];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", out));
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature names to the Booster.
|
||||
*
|
||||
* @param featureNames
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureNames(String[] featureNames) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
|
||||
handle, "feature_name", featureNames));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature types from the Booster.
|
||||
* @return
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public final String[] getFeatureTypes() throws XGBoostError {
|
||||
int numFeature = (int) getNumFeature();
|
||||
String[] out = new String[numFeature];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", out));
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature types to the Booster.
|
||||
* @param featureTypes
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
|
||||
handle, "feature_type", featureTypes));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the booster for one iteration.
|
||||
*
|
||||
@@ -173,34 +218,48 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
this.boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration.
|
||||
* @param obj customized objective class
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
this.boost(dtrain, iter, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
this.boost(dtrain, 0, grad, hess);
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
* Update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration.
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError {
|
||||
if (grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
|
||||
dtrain.getHandle(), grad, hess));
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle,
|
||||
dtrain.getHandle(), iter, grad, hess));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -744,7 +803,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeInt(version);
|
||||
out.writeObject(this.toByteArray());
|
||||
out.writeObject(this.toByteArray("ubj"));
|
||||
} catch (XGBoostError ex) {
|
||||
ex.printStackTrace();
|
||||
logger.error(ex.getMessage());
|
||||
@@ -780,7 +839,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
@Override
|
||||
public void write(Kryo kryo, Output output) {
|
||||
try {
|
||||
byte[] serObj = this.toByteArray();
|
||||
byte[] serObj = this.toByteArray("ubj");
|
||||
int serObjSize = serObj.length;
|
||||
output.writeInt(serObjSize);
|
||||
output.writeInt(version);
|
||||
|
||||
@@ -17,6 +17,8 @@ package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
@@ -30,6 +32,11 @@ import org.apache.hadoop.fs.FileSystem;
|
||||
public class XGBoost {
|
||||
private static final Log logger = LogFactory.getLog(XGBoost.class);
|
||||
|
||||
public static final String[] MAXIMIZ_METRICES = {
|
||||
"auc", "aucpr", "pre", "pre@", "map", "ndcg",
|
||||
"auc@", "aucpr@", "map@", "ndcg@",
|
||||
};
|
||||
|
||||
/**
|
||||
* load model from modelPath
|
||||
*
|
||||
@@ -158,7 +165,7 @@ public class XGBoost {
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
float bestScore;
|
||||
float bestScore = 1;
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
@@ -175,11 +182,7 @@ public class XGBoost {
|
||||
|
||||
evalNames = names.toArray(new String[names.size()]);
|
||||
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||
|
||||
@@ -198,6 +201,8 @@ public class XGBoost {
|
||||
if (booster == null) {
|
||||
// Start training on a new booster
|
||||
booster = new Booster(params, allMats);
|
||||
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||
booster.loadRabitCheckpoint();
|
||||
} else {
|
||||
// Start training on an existing booster
|
||||
@@ -208,6 +213,9 @@ public class XGBoost {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
boolean initial_best_score_flag = false;
|
||||
boolean max_direction = false;
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
@@ -229,6 +237,18 @@ public class XGBoost {
|
||||
} else {
|
||||
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
|
||||
}
|
||||
|
||||
if (!initial_best_score_flag) {
|
||||
if (isMaximizeEvaluation(evalInfo, evalNames, params)) {
|
||||
max_direction = true;
|
||||
bestScore = -Float.MAX_VALUE;
|
||||
} else {
|
||||
max_direction = false;
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
initial_best_score_flag = true;
|
||||
}
|
||||
|
||||
for (int i = 0; i < metricsOut.length; i++) {
|
||||
metrics[i][iter] = metricsOut[i];
|
||||
}
|
||||
@@ -236,7 +256,7 @@ public class XGBoost {
|
||||
// If there is more than one evaluation datasets, the last one would be used
|
||||
// to determinate early stop.
|
||||
float score = metricsOut[metricsOut.length - 1];
|
||||
if (isMaximizeEvaluation(params)) {
|
||||
if (max_direction) {
|
||||
// Update best score if the current score is better (no update when equal)
|
||||
if (score > bestScore) {
|
||||
bestScore = score;
|
||||
@@ -262,9 +282,7 @@ public class XGBoost {
|
||||
break;
|
||||
}
|
||||
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
|
||||
if (shouldPrint(params, iter)){
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
Communicator.communicatorPrint(evalInfo + '\n');
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
@@ -358,16 +376,50 @@ public class XGBoost {
|
||||
return iter - bestIteration >= earlyStoppingRounds;
|
||||
}
|
||||
|
||||
private static boolean isMaximizeEvaluation(Map<String, Object> params) {
|
||||
try {
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
assert(maximize != null);
|
||||
return Boolean.valueOf(maximize);
|
||||
} catch (Exception ex) {
|
||||
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
|
||||
" allowed value: true/false", ex);
|
||||
throw ex;
|
||||
private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
|
||||
String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):";
|
||||
Pattern pattern = Pattern.compile(regexPattern);
|
||||
Matcher matcher = pattern.matcher(evalInfo);
|
||||
|
||||
String metricName = null;
|
||||
if (matcher.find()) {
|
||||
metricName = matcher.group(1);
|
||||
logger.debug("Got the metric name: " + metricName);
|
||||
}
|
||||
return metricName;
|
||||
}
|
||||
|
||||
// visiable for testing
|
||||
public static boolean isMaximizeEvaluation(String evalInfo,
|
||||
String[] evalNames,
|
||||
Map<String, Object> params) {
|
||||
|
||||
String metricName;
|
||||
|
||||
if (params.get("maximize_evaluation_metrics") != null) {
|
||||
// user has forced the direction no matter what is the metric name.
|
||||
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
|
||||
return Boolean.valueOf(maximize);
|
||||
}
|
||||
|
||||
if (params.get("eval_metric") != null) {
|
||||
// user has special metric name
|
||||
metricName = String.valueOf(params.get("eval_metric"));
|
||||
} else {
|
||||
// infer the metric name from log
|
||||
metricName = getMetricNameFromlog(evalInfo, evalNames);
|
||||
}
|
||||
|
||||
assert metricName != null;
|
||||
|
||||
if (!"mape".equals(metricName)) {
|
||||
for (String x : MAXIMIZ_METRICES) {
|
||||
if (metricName.startsWith(x)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -110,7 +110,7 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||
|
||||
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
|
||||
public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad,
|
||||
float[] hess);
|
||||
|
||||
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,
|
||||
@@ -164,4 +164,8 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
}
|
||||
|
||||
@@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter, obj)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, iter, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -205,6 +205,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.setBaseMargin(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* set feature names
|
||||
* @param values feature names
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setFeatureNames(values: Array[String]): Unit = {
|
||||
jDMatrix.setFeatureNames(values)
|
||||
}
|
||||
|
||||
/**
|
||||
* set feature types
|
||||
* @param values feature types
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setFeatureTypes(values: Array[String]): Unit = {
|
||||
jDMatrix.setFeatureTypes(values)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix (used for ranking)
|
||||
*/
|
||||
@@ -243,6 +263,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.getBaseMargin
|
||||
}
|
||||
|
||||
/**
|
||||
* get feature names
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
* @return
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureNames: Array[String] = {
|
||||
jDMatrix.getFeatureNames
|
||||
}
|
||||
|
||||
/**
|
||||
* get feature types
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
* @return
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureTypes: Array[String] = {
|
||||
jDMatrix.getFeatureTypes
|
||||
}
|
||||
|
||||
/**
|
||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||
*
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/c_api/c_api_utils.h"
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
@@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)V
|
||||
* Method: XGBoosterTrainOneIter
|
||||
* Signature: (JJI[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
||||
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
||||
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
||||
JVM_CHECK_CALL(ret);
|
||||
//release
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter(
|
||||
JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad,
|
||||
jfloatArray jhess) {
|
||||
API_BEGIN();
|
||||
BoosterHandle handle = reinterpret_cast<BoosterHandle *>(jhandle);
|
||||
DMatrixHandle dtrain = reinterpret_cast<DMatrixHandle *>(jdtrain);
|
||||
CHECK(handle);
|
||||
CHECK(dtrain);
|
||||
bst_ulong n_samples{0};
|
||||
JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples));
|
||||
|
||||
bst_ulong len = static_cast<bst_ulong>(jenv->GetArrayLength(jgrad));
|
||||
jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr);
|
||||
jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr);
|
||||
CHECK(grad);
|
||||
CHECK(hess);
|
||||
|
||||
xgboost::bst_target_t n_targets{1};
|
||||
if (len != n_samples && n_samples != 0) {
|
||||
CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient.";
|
||||
n_targets = len / n_samples;
|
||||
}
|
||||
|
||||
auto ctx = xgboost::detail::BoosterCtx(handle);
|
||||
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
|
||||
ctx, grad, hess, xgboost::linalg::kC, n_samples, n_targets);
|
||||
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
|
||||
s_hess.c_str());
|
||||
|
||||
// release
|
||||
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
||||
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
||||
|
||||
return ret;
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -1148,3 +1171,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea
|
||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo(
|
||||
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
|
||||
jobjectArray jfeatures) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
|
||||
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||
|
||||
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures);
|
||||
|
||||
std::vector<std::string> features;
|
||||
std::vector<char const*> features_char;
|
||||
|
||||
for (bst_ulong i = 0; i < feature_num; ++i) {
|
||||
jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i);
|
||||
const char *s = jenv->GetStringUTFChars(jfeature, 0);
|
||||
features.push_back(std::string(s, jenv->GetStringLength(jfeature)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature, s);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < features.size(); ++i) {
|
||||
features_char.push_back(features[i].c_str());
|
||||
}
|
||||
|
||||
int ret = XGBoosterSetStrFeatureInfo(
|
||||
handle, field, dmlc::BeginPtr(features_char), feature_num);
|
||||
JVM_CHECK_CALL(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetGtrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
|
||||
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
|
||||
jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
|
||||
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||
|
||||
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout);
|
||||
|
||||
const char **features;
|
||||
std::vector<char *> features_char;
|
||||
|
||||
int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num,
|
||||
(const char ***)&features);
|
||||
JVM_CHECK_CALL(ret);
|
||||
|
||||
for (bst_ulong i = 0; i < feature_num; i++) {
|
||||
jstring jfeature = jenv->NewStringUTF(features[i]);
|
||||
jenv->SetObjectArrayElement(jout, i, jfeature);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)I
|
||||
* Method: XGBoosterTrainOneIter
|
||||
* Signature: (JJI[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -383,6 +383,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixC
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -16,10 +16,7 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
@@ -122,6 +119,39 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
String[] featureNames = new String[126];
|
||||
String[] featureTypes = new String[126];
|
||||
for(int i = 0; i < 126; i++) {
|
||||
featureNames[i] = "test_feature_name_" + i;
|
||||
featureTypes[i] = "q";
|
||||
}
|
||||
trainMat.setFeatureNames(featureNames);
|
||||
testMat.setFeatureNames(featureNames);
|
||||
trainMat.setFeatureTypes(featureTypes);
|
||||
testMat.setFeatureTypes(featureTypes);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
// save and load, only json format save and load feature_name and feature_type
|
||||
File temp = File.createTempFile("temp", ".json");
|
||||
temp.deleteOnExit();
|
||||
booster.saveModel(temp.getAbsolutePath());
|
||||
|
||||
String modelString = new String(booster.toByteArray("json"));
|
||||
|
||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||
assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
|
||||
assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));
|
||||
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
/*
|
||||
Copyright (c) 2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
public class XGBoostTest {
|
||||
|
||||
private String composeEvalInfo(String metric, String evalName) {
|
||||
return "[0]\t" + evalName + "-" + metric + ":" + "\ttest";
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsMaximizeEvaluation() {
|
||||
String[] minimum_metrics = {"mape", "logloss", "error", "others"};
|
||||
String[] evalNames = {"set-abc"};
|
||||
|
||||
HashMap<String, Object> params = new HashMap<>();
|
||||
|
||||
// test1, infer the metric from faked log
|
||||
for (String x : XGBoost.MAXIMIZ_METRICES) {
|
||||
String evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
}
|
||||
|
||||
// test2, the direction for mape should be minimum
|
||||
String evalInfo = composeEvalInfo("mape", evalNames[0]);
|
||||
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
|
||||
// test3, force maximize_evaluation_metrics
|
||||
params.clear();
|
||||
params.put("maximize_evaluation_metrics", true);
|
||||
// auc should be max,
|
||||
evalInfo = composeEvalInfo("auc", evalNames[0]);
|
||||
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
|
||||
params.clear();
|
||||
params.put("maximize_evaluation_metrics", false);
|
||||
// auc should be min,
|
||||
evalInfo = composeEvalInfo("auc", evalNames[0]);
|
||||
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
|
||||
// test4, set the metric manually
|
||||
for (String x : XGBoost.MAXIMIZ_METRICES) {
|
||||
params.clear();
|
||||
params.put("eval_metric", x);
|
||||
evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||
TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
}
|
||||
|
||||
// test5, set the metric manually
|
||||
for (String x : minimum_metrics) {
|
||||
params.clear();
|
||||
params.put("eval_metric", x);
|
||||
evalInfo = composeEvalInfo(x, evalNames[0]);
|
||||
TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEarlyStop() throws XGBoostError {
|
||||
Random random = new Random(1);
|
||||
|
||||
java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
|
||||
int nrep = 3000;
|
||||
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
|
||||
for (int i = 0; i < nrep; ++i) {
|
||||
LabeledPoint p = new LabeledPoint(
|
||||
i % 2, 4,
|
||||
new int[]{0, 1, 2, 3},
|
||||
new float[]{random.nextFloat(), random.nextFloat(), random.nextFloat(), random.nextFloat()});
|
||||
blist.add(p);
|
||||
labelall.add(p.label());
|
||||
}
|
||||
|
||||
DMatrix dmat = new DMatrix(blist.iterator(), null);
|
||||
|
||||
int round = 50;
|
||||
int earlyStop = 2;
|
||||
|
||||
HashMap<String, Object> mapParams = new HashMap<>();
|
||||
mapParams.put("eta", 0.1);
|
||||
mapParams.put("objective", "binary:logistic");
|
||||
mapParams.put("max_depth", 3);
|
||||
mapParams.put("eval_metric", "auc");
|
||||
mapParams.put("silent", 0);
|
||||
|
||||
HashMap<String, DMatrix> mapWatches = new HashMap<>();
|
||||
mapWatches.put("selTrain-*", dmat);
|
||||
|
||||
try {
|
||||
Booster booster = XGBoost.train(dmat, mapParams, round, mapWatches, null, null, null, earlyStop);
|
||||
Map<String, String> attrs = booster.getAttrs();
|
||||
TestCase.assertTrue(Integer.valueOf(attrs.get("best_iteration")) < round - 1);
|
||||
} catch (Exception e) {
|
||||
TestCase.assertFalse(false);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user