Compare commits
26 Commits
v1.7.3
...
release_1.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea6b117a57 | ||
|
|
1830a5c5cb | ||
|
|
163149cb10 | ||
|
|
40b4a45770 | ||
|
|
d83db4844b | ||
|
|
3550b16a34 | ||
|
|
917b0a7b46 | ||
|
|
58ebbab979 | ||
|
|
2bc5d8d449 | ||
|
|
7d178cbd25 | ||
|
|
74e2f652de | ||
|
|
e02fff53f2 | ||
|
|
fcb2efbadd | ||
|
|
f4621f09c7 | ||
|
|
bf1b2cbfa2 | ||
|
|
d90e7b3117 | ||
|
|
088c43d666 | ||
|
|
69fc8a632f | ||
|
|
213f4fa45a | ||
|
|
5ca21f252a | ||
|
|
eeb67c3d52 | ||
|
|
ed37fdb9c9 | ||
|
|
e7e522fb06 | ||
|
|
8e39a675be | ||
|
|
7f542d2198 | ||
|
|
c8d32102fb |
@@ -1,5 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(xgboost LANGUAGES CXX C VERSION 1.0.0)
|
||||
project(xgboost LANGUAGES CXX C VERSION 1.0.2)
|
||||
include(cmake/Utils.cmake)
|
||||
list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules")
|
||||
cmake_policy(SET CMP0022 NEW)
|
||||
@@ -49,7 +49,7 @@ option(USE_SANITIZER "Use santizer flags" OFF)
|
||||
option(SANITIZER_PATH "Path to sanitizes.")
|
||||
set(ENABLED_SANITIZERS "address" "leak" CACHE STRING
|
||||
"Semicolon separated list of sanitizer names. E.g 'address;leak'. Supported sanitizers are
|
||||
address, leak and thread.")
|
||||
address, leak, undefined and thread.")
|
||||
## Plugins
|
||||
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
|
||||
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
||||
|
||||
27
Jenkinsfile
vendored
27
Jenkinsfile
vendored
@@ -95,6 +95,17 @@ pipeline {
|
||||
milestone ordinal: 4
|
||||
}
|
||||
}
|
||||
stage('Jenkins Linux: Deploy') {
|
||||
agent none
|
||||
steps {
|
||||
script {
|
||||
parallel ([
|
||||
'deploy-jvm-packages': { DeployJVMPackages(spark_version: '2.4.3') }
|
||||
])
|
||||
}
|
||||
milestone ordinal: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,6 +284,7 @@ def TestPythonCPU() {
|
||||
def docker_binary = "docker"
|
||||
sh """
|
||||
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu
|
||||
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu-py35
|
||||
"""
|
||||
deleteDir()
|
||||
}
|
||||
@@ -379,3 +391,18 @@ def TestR(args) {
|
||||
deleteDir()
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Deploy the JVM packages from a 'linux && cpu' node.
 *
 * The stashed sources ('srcs') are restored first. The deployment script
 * tests/ci_build/deploy_jvm_packages.sh is only run when the branch is
 * 'master' or its name starts with 'release'; it receives args.spark_version
 * as its single argument. The workspace is deleted afterwards in either case.
 */
def DeployJVMPackages(args) {
  node('linux && cpu') {
    unstash name: 'srcs'
    // Only master and release branches publish artifacts.
    def onDeployBranch = (env.BRANCH_NAME == 'master') || env.BRANCH_NAME.startsWith('release')
    if (onDeployBranch) {
      echo 'Deploying to xgboost-maven-repo S3 repo...'
      def image = "jvm"
      def engine = "docker"
      sh """
      ${dockerRun} ${image} ${engine} tests/ci_build/deploy_jvm_packages.sh ${args.spark_version}
      """
    }
    deleteDir()
  }
}
|
||||
|
||||
@@ -139,6 +139,8 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
|
||||
#' prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
|
||||
#' or predinteraction flags is TRUE.
|
||||
#' @param training whether the prediction result is used for training. For dart booster,
|
||||
#' training predicting will perform dropout.
|
||||
#' @param ... Parameters passed to \code{predict.xgb.Booster}
|
||||
#'
|
||||
#' @details
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
\name{agaricus.test}
|
||||
\alias{agaricus.test}
|
||||
\title{Test part from Mushroom Data Set}
|
||||
\format{A list containing a label vector, and a dgCMatrix object with 1611
|
||||
\format{A list containing a label vector, and a dgCMatrix object with 1611
|
||||
rows and 126 variables}
|
||||
\usage{
|
||||
data(agaricus.test)
|
||||
@@ -24,8 +24,8 @@ This data set includes the following fields:
|
||||
\references{
|
||||
https://archive.ics.uci.edu/ml/datasets/Mushroom
|
||||
|
||||
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
School of Information and Computer Science.
|
||||
}
|
||||
\keyword{datasets}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
\name{agaricus.train}
|
||||
\alias{agaricus.train}
|
||||
\title{Training part from Mushroom Data Set}
|
||||
\format{A list containing a label vector, and a dgCMatrix object with 6513
|
||||
\format{A list containing a label vector, and a dgCMatrix object with 6513
|
||||
rows and 127 variables}
|
||||
\usage{
|
||||
data(agaricus.train)
|
||||
@@ -24,8 +24,8 @@ This data set includes the following fields:
|
||||
\references{
|
||||
https://archive.ics.uci.edu/ml/datasets/Mushroom
|
||||
|
||||
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
School of Information and Computer Science.
|
||||
}
|
||||
\keyword{datasets}
|
||||
|
||||
@@ -49,6 +49,9 @@ It will use all the trees by default (\code{NULL} value).}
|
||||
prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
|
||||
or predinteraction flags is TRUE.}
|
||||
|
||||
\item{training}{whether the prediction result is used for training. For dart booster,
|
||||
training predicting will perform dropout.}
|
||||
|
||||
\item{...}{Parameters passed to \code{predict.xgb.Booster}}
|
||||
}
|
||||
\value{
|
||||
|
||||
@@ -31,7 +31,6 @@ num_round <- 2
|
||||
test_that("custom objective works", {
|
||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||
expect_equal(class(bst), "xgb.Booster")
|
||||
expect_equal(length(bst$raw), 1100)
|
||||
expect_false(is.null(bst$evaluation_log))
|
||||
expect_false(is.null(bst$evaluation_log$eval_error))
|
||||
expect_lt(bst$evaluation_log[num_round, eval_error], 0.03)
|
||||
@@ -58,5 +57,4 @@ test_that("custom objective using DMatrix attr works", {
|
||||
param$objective = logregobjattr
|
||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||
expect_equal(class(bst), "xgb.Booster")
|
||||
expect_equal(length(bst$raw), 1100)
|
||||
})
|
||||
|
||||
@@ -7,8 +7,8 @@ require(vcd, quietly = TRUE)
|
||||
|
||||
float_tolerance = 5e-6
|
||||
|
||||
# disable some tests for Win32
|
||||
win32_flag = .Platform$OS.type == "windows" && .Machine$sizeof.pointer != 8
|
||||
# disable some tests for 32-bit environment
|
||||
flag_32bit = .Machine$sizeof.pointer != 8
|
||||
|
||||
set.seed(1982)
|
||||
data(Arthritis)
|
||||
@@ -44,7 +44,7 @@ mbst.GLM <- xgboost(data = as.matrix(iris[, -5]), label = mlabel, verbose = 0,
|
||||
|
||||
|
||||
test_that("xgb.dump works", {
|
||||
if (!win32_flag)
|
||||
if (!flag_32bit)
|
||||
expect_length(xgb.dump(bst.Tree), 200)
|
||||
dump_file = file.path(tempdir(), 'xgb.model.dump')
|
||||
expect_true(xgb.dump(bst.Tree, dump_file, with_stats = T))
|
||||
@@ -54,7 +54,7 @@ test_that("xgb.dump works", {
|
||||
# JSON format
|
||||
dmp <- xgb.dump(bst.Tree, dump_format = "json")
|
||||
expect_length(dmp, 1)
|
||||
if (!win32_flag)
|
||||
if (!flag_32bit)
|
||||
expect_length(grep('nodeid', strsplit(dmp, '\n')[[1]]), 188)
|
||||
})
|
||||
|
||||
@@ -256,7 +256,7 @@ test_that("xgb.model.dt.tree works with and without feature names", {
|
||||
names.dt.trees <- c("Tree", "Node", "ID", "Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")
|
||||
dt.tree <- xgb.model.dt.tree(feature_names = feature.names, model = bst.Tree)
|
||||
expect_equal(names.dt.trees, names(dt.tree))
|
||||
if (!win32_flag)
|
||||
if (!flag_32bit)
|
||||
expect_equal(dim(dt.tree), c(188, 10))
|
||||
expect_output(str(dt.tree), 'Feature.*\\"Age\\"')
|
||||
|
||||
@@ -283,7 +283,7 @@ test_that("xgb.model.dt.tree throws error for gblinear", {
|
||||
|
||||
test_that("xgb.importance works with and without feature names", {
|
||||
importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree)
|
||||
if (!win32_flag)
|
||||
if (!flag_32bit)
|
||||
expect_equal(dim(importance.Tree), c(7, 4))
|
||||
expect_equal(colnames(importance.Tree), c("Feature", "Gain", "Cover", "Frequency"))
|
||||
expect_output(str(importance.Tree), 'Feature.*\\"Age\\"')
|
||||
|
||||
@@ -1 +1 @@
|
||||
@xgboost_VERSION_MAJOR@.@xgboost_VERSION_MINOR@.@xgboost_VERSION_PATCH@-SNAPSHOT
|
||||
@xgboost_VERSION_MAJOR@.@xgboost_VERSION_MINOR@.@xgboost_VERSION_PATCH@
|
||||
|
||||
@@ -8,15 +8,143 @@ XGBoost JVM Package
|
||||
<img alt="Build Status" src="https://travis-ci.org/dmlc/xgboost.svg?branch=master">
|
||||
</a>
|
||||
<a href="https://github.com/dmlc/xgboost/blob/master/LICENSE">
|
||||
<img alt="GitHub license" src="http://dmlc.github.io/img/apache2.svg">
|
||||
<img alt="GitHub license" src="https://dmlc.github.io/img/apache2.svg">
|
||||
</a>
|
||||
|
||||
You have found the XGBoost JVM Package!
|
||||
|
||||
.. _install_jvm_packages:
|
||||
|
||||
************
|
||||
Installation
|
||||
************
|
||||
|
||||
.. contents::
|
||||
:local:
|
||||
:backlinks: none
|
||||
|
||||
Installation from Maven repository
|
||||
==================================
|
||||
|
||||
Access release version
|
||||
----------------------
|
||||
You can use XGBoost4J in your Java/Scala application by adding XGBoost4J as a dependency:
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: Maven
|
||||
|
||||
<properties>
|
||||
...
|
||||
<!-- Specify Scala version in package name -->
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
...
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
libraryDependencies ++= Seq(
|
||||
"ml.dmlc" %% "xgboost4j" % "latest_version_num",
|
||||
"ml.dmlc" %% "xgboost4j-spark" % "latest_version_num"
|
||||
)
|
||||
|
||||
This will check out the latest stable version from the Maven Central.
|
||||
|
||||
For the latest release version number, please check `here <https://github.com/dmlc/xgboost/releases>`_.
|
||||
|
||||
.. note:: Using Maven repository hosted by the XGBoost project
|
||||
|
||||
There may be some delay until a new release becomes available to Maven Central. If you would like to access the latest release immediately, add the Maven repository hosted by the XGBoost project:
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: Maven
|
||||
|
||||
<repository>
|
||||
<id>XGBoost4J Release Repo</id>
|
||||
<name>XGBoost4J Release Repo</name>
|
||||
<url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/release/</url>
|
||||
</repository>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
resolvers += "XGBoost4J Release Repo" at "https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/release/"
|
||||
|
||||
Access SNAPSHOT version
|
||||
-----------------------
|
||||
|
||||
First add the following Maven repository hosted by the XGBoost project:
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: Maven
|
||||
|
||||
<repository>
|
||||
<id>XGBoost4J Snapshot Repo</id>
|
||||
<name>XGBoost4J Snapshot Repo</name>
|
||||
<url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/</url>
|
||||
</repository>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
resolvers += "XGBoost4J Snapshot Repo" at "https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/"
|
||||
|
||||
Then add XGBoost4J as a dependency:
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: maven
|
||||
|
||||
<properties>
|
||||
...
|
||||
<!-- Specify Scala version in package name -->
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
...
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>latest_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||
<version>latest_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
libraryDependencies ++= Seq(
|
||||
"ml.dmlc" %% "xgboost4j" % "latest_version_num-SNAPSHOT",
|
||||
"ml.dmlc" %% "xgboost4j-spark" % "latest_version_num-SNAPSHOT"
|
||||
)
|
||||
|
||||
Look up the ``version`` field in `pom.xml <https://github.com/dmlc/xgboost/blob/master/jvm-packages/pom.xml>`_ to get the correct version number.
|
||||
|
||||
The SNAPSHOT JARs are hosted by the XGBoost project. Every commit in the ``master`` branch will automatically trigger generation of a new SNAPSHOT JAR. You can control how often Maven should upgrade your SNAPSHOT installation by specifying ``updatePolicy``. See `here <http://maven.apache.org/pom.html#Repositories>`_ for details.
|
||||
|
||||
You can browse the file listing of the Maven repository at https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/list.html.
|
||||
|
||||
.. note:: Windows not supported by published JARs
|
||||
|
||||
The published JARs from Maven Central and GitHub currently only support Linux and macOS. Windows users should consider building XGBoost4J / XGBoost4J-Spark from the source. Alternatively, check out pre-built JARs from `criteo-forks/xgboost-jars <https://github.com/criteo-forks/xgboost-jars>`_.
|
||||
|
||||
Installation from source
|
||||
========================
|
||||
|
||||
@@ -64,73 +192,6 @@ If you want to use XGBoost4J-Spark, replace ``xgboost4j`` with ``xgboost4j-spark
|
||||
|
||||
Also, make sure to install Spark directly from `Apache website <https://spark.apache.org/>`_. **Upstream XGBoost is not guaranteed to work with third-party distributions of Spark, such as Cloudera Spark.** Consult appropriate third parties to obtain their distribution of XGBoost.
|
||||
|
||||
Installation from maven repo
|
||||
============================
|
||||
|
||||
Access release version
|
||||
----------------------
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: maven
|
||||
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
"ml.dmlc" % "xgboost4j" % "latest_version_num"
|
||||
|
||||
This will checkout the latest stable version from the Maven Central.
|
||||
|
||||
For the latest release version number, please check `here <https://github.com/dmlc/xgboost/releases>`_.
|
||||
|
||||
if you want to use XGBoost4J-Spark, replace ``xgboost4j`` with ``xgboost4j-spark``.
|
||||
|
||||
Access SNAPSHOT version
|
||||
-----------------------
|
||||
|
||||
You need to add GitHub as repo:
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: maven
|
||||
|
||||
<repository>
|
||||
<id>GitHub Repo</id>
|
||||
<name>GitHub Repo</name>
|
||||
<url>https://raw.githubusercontent.com/CodingCat/xgboost/maven-repo/</url>
|
||||
</repository>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
resolvers += "GitHub Repo" at "https://raw.githubusercontent.com/CodingCat/xgboost/maven-repo/"
|
||||
|
||||
Then add dependency as following:
|
||||
|
||||
.. code-block:: xml
|
||||
:caption: maven
|
||||
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
|
||||
.. code-block:: scala
|
||||
:caption: sbt
|
||||
|
||||
"ml.dmlc" % "xgboost4j" % "latest_version_num"
|
||||
|
||||
For the latest release version number, please check `here <https://github.com/CodingCat/xgboost/tree/maven-repo/ml/dmlc/xgboost4j>`_.
|
||||
|
||||
.. note:: Windows not supported by published JARs
|
||||
|
||||
The published JARs from the Maven Central and GitHub currently only supports Linux and MacOS. Windows users should consider building XGBoost4J / XGBoost4J-Spark from the source. Alternatively, checkout pre-built JARs from `criteo-forks/xgboost-jars <https://github.com/criteo-forks/xgboost-jars>`_.
|
||||
|
||||
Enabling OpenMP for Mac OS
|
||||
--------------------------
|
||||
If you are on Mac OS and using a compiler that supports OpenMP, you need to go to the file ``xgboost/jvm-packages/create_jni.py`` and comment out the line
|
||||
|
||||
@@ -27,39 +27,7 @@ Build an ML Application with XGBoost4J-Spark
|
||||
Refer to XGBoost4J-Spark Dependency
|
||||
===================================
|
||||
|
||||
Before we go into the tour of how to use XGBoost4J-Spark, we would bring a brief introduction about how to build a machine learning application with XGBoost4J-Spark. The first thing you need to do is to refer to the dependency in Maven Central.
|
||||
|
||||
You can add the following dependency in your ``pom.xml``.
|
||||
|
||||
.. code-block:: xml
|
||||
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark</artifactId>
|
||||
<version>latest_version_num</version>
|
||||
</dependency>
|
||||
|
||||
For the latest release version number, please check `here <https://github.com/dmlc/xgboost/releases>`_.
|
||||
|
||||
We also publish some functionalities which would be included in the coming release in the form of snapshot version. To access these functionalities, you can add dependency to the snapshot artifacts. We publish snapshot version in github-based repo, so you can add the following repo in ``pom.xml``:
|
||||
|
||||
.. code-block:: xml
|
||||
|
||||
<repository>
|
||||
<id>XGBoost4J-Spark Snapshot Repo</id>
|
||||
<name>XGBoost4J-Spark Snapshot Repo</name>
|
||||
<url>https://raw.githubusercontent.com/CodingCat/xgboost/maven-repo/</url>
|
||||
</repository>
|
||||
|
||||
and then refer to the snapshot dependency by adding:
|
||||
|
||||
.. code-block:: xml
|
||||
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark</artifactId>
|
||||
<version>next_version_num-SNAPSHOT</version>
|
||||
</dependency>
|
||||
Before we go into the tour of how to use XGBoost4J-Spark, you should first consult :ref:`Installation from Maven repository <install_jvm_packages>` in order to add XGBoost4J-Spark as a dependency for your project. We provide both stable releases and snapshots.
|
||||
|
||||
.. note:: XGBoost4J-Spark requires Apache Spark 2.4+
|
||||
|
||||
|
||||
@@ -195,12 +195,22 @@
|
||||
"properties": {
|
||||
"version": {
|
||||
"type": "array",
|
||||
"const": [
|
||||
1,
|
||||
0,
|
||||
0
|
||||
"items": [
|
||||
{
|
||||
"type": "number",
|
||||
"const": 1
|
||||
},
|
||||
{
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
{
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
}
|
||||
],
|
||||
"additionalItems": false
|
||||
"minItems": 3,
|
||||
"maxItems": 3
|
||||
},
|
||||
"learner": {
|
||||
"type": "object",
|
||||
|
||||
79
doc/python/convert_090to100.py
Normal file
79
doc/python/convert_090to100.py
Normal file
@@ -0,0 +1,79 @@
|
||||
'''This is a simple script that converts a pickled XGBoost
|
||||
Scikit-Learn interface object from 0.90 to a native model. Pickle
|
||||
format is not stable as it's a direct serialization of Python object.
|
||||
We advise against using it when stability is needed.
|
||||
|
||||
'''
|
||||
import pickle
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import xgboost
|
||||
import warnings
|
||||
|
||||
|
||||
def save_label_encoder(le):
    '''Collect the attributes of the label encoder used by XGBClassifier.

    Every entry of ``le.__dict__`` is copied into a new dictionary. Values
    that are ``numpy.ndarray`` instances are converted to plain lists with
    ``tolist()``; all other values are stored unchanged.
    '''
    return {
        name: (value.tolist() if isinstance(value, np.ndarray) else value)
        for name, value in le.__dict__.items()
    }
|
||||
|
||||
|
||||
def xgboost_skl_90to100(skl_model):
    '''Extract the model and related metadata in SKL model.

    Parameters
    ----------
    skl_model : str
        Path to a pickle file holding an XGBoost Scikit-Learn interface
        object.

    The JSON-serializable attributes of the unpickled object are stored as the
    ``scikit_learn`` attribute of its booster, which is then written with
    ``save_model`` to ``xgboost_native_model_from_<pickle file name>-<i>.bin``
    relative to the current working directory. ``<i>`` is the first index for
    which no such file exists yet, so earlier outputs are never overwritten.

    Raises
    ------
    TypeError
        If the unpickled object is not an ``xgboost.XGBModel``.
    '''
    model = {}
    # NOTE(review): pickle.load can execute arbitrary code; only run this
    # script on pickle files from a trusted source.
    with open(skl_model, 'rb') as fd:
        old = pickle.load(fd)
    if not isinstance(old, xgboost.XGBModel):
        raise TypeError(
            'The script only handles Scikit-Learn interface object')

    # Save Scikit-Learn specific Python attributes into a JSON document.
    for k, v in old.__dict__.items():
        if k == '_le':
            model[k] = save_label_encoder(v)
        elif k == 'classes_':
            model[k] = v.tolist()
        elif k == '_Booster':
            # The booster itself is saved separately as the native model.
            continue
        else:
            try:
                # Probe whether this attribute can be serialized on its own.
                json.dumps({k: v})
                model[k] = v
            except TypeError:
                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
    booster = old.get_booster()
    # Store the JSON serialization as an attribute
    booster.set_attr(scikit_learn=json.dumps(model))

    # Save it into a native model.
    # Only the file name of the input is used: concatenating a path that has a
    # directory component would yield an output path inside a directory that
    # does not exist.
    prefix = 'xgboost_native_model_from_' + os.path.basename(skl_model)
    i = 0
    while True:
        path = prefix + '-' + str(i) + '.bin'
        if os.path.exists(path):
            i += 1
            continue
        booster.save_model(path)
        break
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Explicit check rather than `assert`: assertions are removed when Python
    # runs with -O, which would silently skip this guard.
    if xgboost.__version__ == '1.0.0':
        raise SystemExit('Please use the XGBoost version'
                         ' that generates this pickle.')
    parser = argparse.ArgumentParser(
        description=('A simple script to convert pickle generated by'
                     ' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
    parser.add_argument(
        '--old-pickle',
        type=str,
        help='Path to old pickle file of Scikit-Learn interface object. '
        'Will output a native model converted from this pickle file',
        required=True)
    args = parser.parse_args()

    # Convert the pickle given on the command line into a native model file.
    xgboost_skl_90to100(args.old_pickle)
|
||||
@@ -91,7 +91,12 @@ Loading pickled file from different version of XGBoost
|
||||
|
||||
As noted, pickled model is neither portable nor stable, but in some cases the pickled
|
||||
models are valuable. One way to restore it in the future is to load it back with that
|
||||
specific version of Python and XGBoost, export the model by calling `save_model`.
|
||||
specific version of Python and XGBoost, export the model by calling `save_model`. To help
|
||||
ease the migration, we created a simple script for converting pickled XGBoost 0.90
|
||||
Scikit-Learn interface object to XGBoost 1.0.0 native model. Please note that the script
|
||||
suits simple use cases, and it's advised not to use pickle when stability is needed.
|
||||
It's located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See
|
||||
comments in the script for more details.
|
||||
|
||||
********************************************************
|
||||
Saving and Loading the internal parameters configuration
|
||||
@@ -190,7 +195,9 @@ You can load it back to the model generated by same version of XGBoost by:
|
||||
|
||||
bst.load_config(config)
|
||||
|
||||
This way users can study the internal representation more closely.
|
||||
This way users can study the internal representation more closely. Please note that some
|
||||
JSON generators make use of locale dependent floating point serialization methods, which
|
||||
is not supported by XGBoost.
|
||||
|
||||
************
|
||||
Future Plans
|
||||
|
||||
@@ -208,6 +208,8 @@ struct LearnerModelParam {
|
||||
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
||||
// this one as an immutable copy.
|
||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
|
||||
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
|
||||
bool Initialized() const { return num_feature != 0; }
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
#define XGBOOST_VER_MAJOR 1
|
||||
#define XGBOOST_VER_MINOR 0
|
||||
#define XGBOOST_VER_PATCH 0
|
||||
#define XGBOOST_VER_PATCH 1
|
||||
|
||||
#endif // XGBOOST_VERSION_CONFIG_H_
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>XGBoost JVM Package</name>
|
||||
<description>JVM Package for XGBoost</description>
|
||||
@@ -37,6 +37,7 @@
|
||||
<spark.version>2.4.3</spark.version>
|
||||
<scala.version>2.12.8</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<hadoop.version>2.7.3</hadoop.version>
|
||||
</properties>
|
||||
<repositories>
|
||||
<repository>
|
||||
@@ -204,6 +205,29 @@
|
||||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>release-to-s3</id>
|
||||
<distributionManagement>
|
||||
<snapshotRepository>
|
||||
<id>maven-s3-snapshot-repo</id>
|
||||
<url>s3://xgboost-maven-repo/snapshot</url>
|
||||
</snapshotRepository>
|
||||
<repository>
|
||||
<id>maven-s3-release-repo</id>
|
||||
<url>s3://xgboost-maven-repo/release</url>
|
||||
</repository>
|
||||
</distributionManagement>
|
||||
<repositories>
|
||||
<repository>
|
||||
<id>maven-s3-snapshot-repo</id>
|
||||
<url>https://s3.amazonaws.com/xgboost-maven-repo/snapshot</url>
|
||||
</repository>
|
||||
<repository>
|
||||
<id>maven-s3-release-repo</id>
|
||||
<url>https://s3.amazonaws.com/xgboost-maven-repo/release</url>
|
||||
</repository>
|
||||
</repositories>
|
||||
</profile>
|
||||
</profiles>
|
||||
<distributionManagement>
|
||||
<snapshotRepository>
|
||||
@@ -323,6 +347,13 @@
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
<extensions>
|
||||
<extension>
|
||||
<groupId>org.kuali.maven.wagons</groupId>
|
||||
<artifactId>maven-s3-wagon</artifactId>
|
||||
<version>1.2.1</version>
|
||||
</extension>
|
||||
</extensions>
|
||||
</build>
|
||||
<reporting>
|
||||
<plugins>
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-example_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
<packaging>jar</packaging>
|
||||
<build>
|
||||
<plugins>
|
||||
@@ -26,7 +26,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
@@ -37,7 +37,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-flink_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
@@ -26,7 +26,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-spark_2.12</artifactId>
|
||||
<build>
|
||||
@@ -24,7 +24,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.SparkContext
|
||||
|
||||
/**
 * Saves a checkpoint every few training rounds. If a previous job failed, training can be
 * restarted from the most recently saved checkpoint instead of from scratch. This class
 * provides the interface and helper methods for the checkpoint functionality.
 *
 * Checkpoints are stored as `<version>.model` files directly under `checkpointPath`. The round
 * a checkpoint belongs to is derived in this class as `version / 2`.
 *
 * NOTE: This checkpoint is different from the Rabit checkpoint. The Rabit checkpoint is a
 * native-level checkpoint stored in executor memory. This is a checkpoint which the Spark
 * driver stores on HDFS every few iterations.
 *
 * @param sc the SparkContext whose Hadoop configuration is used to obtain the file system
 * @param checkpointPath the HDFS path to store checkpoints; an empty string means that no
 *                       checkpoint directory is configured
 */
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
  private val logger = LogFactory.getLog("XGBoostSpark")
  private val modelSuffix = ".model"

  /** Resolves the file system from the Spark context's Hadoop configuration. */
  private def fileSystem: FileSystem = FileSystem.get(sc.hadoopConfiguration)

  /** Builds the full path of the checkpoint file holding the given booster version. */
  private def getPath(version: Int) = {
    s"$checkpointPath/$version$modelSuffix"
  }

  /**
   * Lists the versions of all checkpoint files currently present under `checkpointPath`.
   *
   * @return the versions parsed from the `<version>.model` file names; empty when the path is
   *         empty or does not exist
   */
  private def getExistingVersions: Seq[Int] = {
    val fs = fileSystem
    if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
      Seq()
    } else {
      fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
        case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
      }
    }
  }

  /** Recursively deletes the whole checkpoint directory, if a checkpoint path is configured. */
  def cleanPath(): Unit = {
    if (checkpointPath.nonEmpty) {
      fileSystem.delete(new Path(checkpointPath), true)
    }
  }

  /**
   * Load existing checkpoint with the highest version as a Booster object
   *
   * @return the booster with the highest version, null if no checkpoints available.
   */
  private[spark] def loadCheckpointAsBooster: Booster = {
    val versions = getExistingVersions
    if (versions.nonEmpty) {
      val version = versions.max
      val fullPath = getPath(version)
      val inputStream = fileSystem.open(new Path(fullPath))
      try {
        logger.info(s"Start training from previous booster at $fullPath")
        val booster = SXGBoost.loadModel(inputStream)
        booster.booster.setVersion(version)
        booster
      } finally {
        // Release the stream even when loading fails; closing an already closed stream is
        // expected to be a no-op.
        inputStream.close()
      }
    } else {
      null
    }
  }

  /**
   * Save a new checkpoint and then clean up all previous checkpoints.
   *
   * The previous checkpoints are removed only after the new one has been written. A
   * pre-existing file with the same version is overwritten in place and is not part of the
   * clean-up, so the checkpoint that was just written is never deleted.
   *
   * @param checkpoint the checkpoint to save as an XGBoostModel
   */
  private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
    val fs = fileSystem
    val version = checkpoint.getVersion
    // Only versions other than the one being written are outdated.
    val outdatedPaths = getExistingVersions.filter(_ != version).map(v => new Path(getPath(v)))
    val fullPath = getPath(version)
    val outputStream = fs.create(new Path(fullPath), true)
    try {
      logger.info(s"Saving checkpoint model with version $version to $fullPath")
      checkpoint.saveModel(outputStream)
    } finally {
      // Release the stream even when saving fails.
      outputStream.close()
    }
    outdatedPaths.foreach(path => fs.delete(path, true))
  }

  /**
   * Clean up checkpoint boosters whose round (`version / 2`) is higher than or equal to the
   * given round.
   *
   * @param round the number of rounds in the current training job
   */
  private[spark] def cleanUpHigherVersions(round: Int): Unit = {
    val higherVersions = getExistingVersions.filter(_ / 2 >= round)
    if (higherVersions.nonEmpty) {
      // Resolve the file system once instead of once per deleted file.
      val fs = fileSystem
      higherVersions.foreach(version => fs.delete(new Path(getPath(version)), true))
    }
  }

  /**
   * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
   * and total number of rounds for the training. Concretely, the checkpoint rounds start with
   * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
   * reaches total number of rounds. If checkpointInterval is not positive, the checkpoint is
   * disabled and the method returns Seq(round).
   *
   * @param checkpointInterval Period (in iterations) between checkpoints.
   * @param round the total number of rounds for the training
   * @return a seq of integers, each represent the index of round to save the checkpoints
   * @throws IllegalArgumentException if checkpointInterval is positive but no checkpoint path
   *                                  is configured
   */
  private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
    if (checkpointPath.nonEmpty && checkpointInterval > 0) {
      val prevRounds = getExistingVersions.map(_ / 2)
      // Prepending 0 makes the first checkpoint land on checkpointInterval for a fresh job.
      val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
      (firstCheckpointRound until round by checkpointInterval) :+ round
    } else if (checkpointInterval <= 0) {
      Seq(round)
    } else {
      throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
    }
  }
}
|
||||
|
||||
object CheckpointManager {

  /**
   * Checkpoint-related settings extracted from the user-supplied parameter map.
   *
   * @param checkpointPath value of "checkpoint_path", or "" when absent
   * @param checkpointInterval value of "checkpoint_interval", or 0 when absent
   * @param skipCleanCheckpoint value of "skip_clean_checkpoint", or false when absent
   */
  case class CheckpointParam(
      checkpointPath: String,
      checkpointInterval: Int,
      skipCleanCheckpoint: Boolean)

  /**
   * Reads the checkpoint settings out of a raw parameter map.
   *
   * A missing key falls back to its default; a key that is present with a value of the wrong
   * type is rejected.
   *
   * @param params the raw parameter map
   * @return the extracted checkpoint settings
   * @throws IllegalArgumentException if "checkpoint_path" is not a String,
   *                                  "checkpoint_interval" is not an Int, or
   *                                  "skip_clean_checkpoint" is not a Boolean
   */
  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
    val path: String = params.get("checkpoint_path").map {
      case value: String => value
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
        " an instance of String.")
    }.getOrElse("")

    val interval: Int = params.get("checkpoint_interval").map {
      case value: Int => value
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
        " an instance of Int.")
    }.getOrElse(0)

    val skipClean: Boolean = params.get("skip_clean_checkpoint").map {
      case value: Boolean => value
      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
        " an instance of Boolean")
    }.getOrElse(false)

    CheckpointParam(path, interval, skipClean)
  }
}
|
||||
@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.commons.io.FileUtils
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
||||
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see
|
||||
|
||||
private[this] case class XGBoostExecutionParams(
|
||||
numWorkers: Int,
|
||||
round: Int,
|
||||
numRounds: Int,
|
||||
useExternalMemory: Boolean,
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
|
||||
allowNonZeroForMissing: Boolean,
|
||||
trackerConf: TrackerConf,
|
||||
timeoutRequestWorkers: Long,
|
||||
checkpointParam: CheckpointParam,
|
||||
checkpointParam: Option[ExternalCheckpointParams],
|
||||
xgbInputParams: XGBoostExecutionInputParams,
|
||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||
cacheTrainingSet: Boolean) {
|
||||
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
.getOrElse("allow_non_zero_for_missing", false)
|
||||
.asInstanceOf[Boolean]
|
||||
validateSparkSslConf
|
||||
|
||||
if (overridedParams.contains("tree_method")) {
|
||||
require(overridedParams("tree_method") == "hist" ||
|
||||
overridedParams("tree_method") == "approx" ||
|
||||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
" an instance of Long.")
|
||||
}
|
||||
val checkpointParam =
|
||||
CheckpointManager.extractParams(overridedParams)
|
||||
ExternalCheckpointParams.extractParams(overridedParams)
|
||||
|
||||
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
|
||||
.asInstanceOf[Double]
|
||||
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
|
||||
watches: Watches,
|
||||
xgbExecutionParam: XGBoostExecutionParams,
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
round: Int,
|
||||
obj: ObjectiveTrait,
|
||||
eval: EvalTrait,
|
||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||
|
||||
// to workaround the empty partitions in training dataset,
|
||||
// this might not be the best efficient implementation, see
|
||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
||||
|
||||
val numRounds = xgbExecutionParam.numRounds
|
||||
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||
try {
|
||||
Rabit.init(rabitEnv)
|
||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
||||
val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||
val booster = if (makeCheckpoint) {
|
||||
SXGBoost.trainAndSaveCheckpoint(
|
||||
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
|
||||
} else {
|
||||
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
||||
watches.toMap, metrics, obj, eval,
|
||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||
}
|
||||
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
||||
} catch {
|
||||
case xgbException: XGBoostError =>
|
||||
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
|
||||
trainingData: RDD[XGBLabeledPoint],
|
||||
xgbExecutionParams: XGBoostExecutionParams,
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
if (evalSetsMap.isEmpty) {
|
||||
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
|
||||
processMissingValues(labeledPoints, xgbExecutionParams.missing,
|
||||
xgbExecutionParams.allowNonZeroForMissing),
|
||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
||||
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
|
||||
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
|
||||
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
|
||||
xgbExecutionParams.eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
|
||||
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
|
||||
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
|
||||
},
|
||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
||||
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
|
||||
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
|
||||
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
|
||||
xgbExecutionParams.eval, prevBooster)
|
||||
}.cache()
|
||||
}
|
||||
}
|
||||
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
|
||||
trainingData: RDD[Array[XGBLabeledPoint]],
|
||||
xgbExecutionParam: XGBoostExecutionParams,
|
||||
rabitEnv: java.util.Map[String, String],
|
||||
checkpointRound: Int,
|
||||
prevBooster: Booster,
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||
if (evalSetsMap.isEmpty) {
|
||||
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
|
||||
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
|
||||
xgbExecutionParam.allowNonZeroForMissing),
|
||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
||||
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
|
||||
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
|
||||
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
|
||||
}).cache()
|
||||
} else {
|
||||
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
|
||||
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
|
||||
},
|
||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
||||
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
|
||||
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
|
||||
xgbExecutionParam.obj,
|
||||
xgbExecutionParam.eval,
|
||||
prevBooster)
|
||||
@@ -529,60 +534,58 @@ object XGBoost extends Serializable {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
|
||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
|
||||
val sc = trainingData.sparkContext
|
||||
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
|
||||
checkpointPath)
|
||||
checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
|
||||
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
|
||||
hasGroup, xgbExecParams.numWorkers)
|
||||
var prevBooster = checkpointManager.loadCheckpointAsBooster
|
||||
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
|
||||
val checkpointManager = new ExternalCheckpointManager(
|
||||
checkpointParam.checkpointPath,
|
||||
FileSystem.get(sc.hadoopConfiguration))
|
||||
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
|
||||
checkpointManager.loadCheckpointAsScalaBooster()
|
||||
}.orNull
|
||||
try {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
val producedBooster = checkpointManager.getCheckpointRounds(
|
||||
xgbExecParams.checkpointParam.checkpointInterval,
|
||||
xgbExecParams.round).map {
|
||||
checkpointRound: Int =>
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
try {
|
||||
val parallelismTracker = new SparkParallelismTracker(sc,
|
||||
xgbExecParams.timeoutRequestWorkers,
|
||||
xgbExecParams.numWorkers)
|
||||
|
||||
tracker.getWorkerEnvs().putAll(xgbRabitParams)
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams,
|
||||
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
|
||||
} else {
|
||||
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
|
||||
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boostersAndMetrics.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
|
||||
boostersAndMetrics, sparkJobThread)
|
||||
if (checkpointRound < xgbExecParams.round) {
|
||||
prevBooster = booster
|
||||
checkpointManager.updateCheckpoint(prevBooster)
|
||||
}
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
val (booster, metrics) = try {
|
||||
val parallelismTracker = new SparkParallelismTracker(sc,
|
||||
xgbExecParams.timeoutRequestWorkers,
|
||||
xgbExecParams.numWorkers)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
|
||||
evalSetsMap)
|
||||
} else {
|
||||
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
|
||||
prevBooster, evalSetsMap)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
override def run() {
|
||||
// force the job
|
||||
boostersAndMetrics.foreachPartition(() => _)
|
||||
}
|
||||
}.last
|
||||
// we should delete the checkpoint directory after a successful training
|
||||
if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
|
||||
checkpointManager.cleanPath()
|
||||
}
|
||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkJobThread.start()
|
||||
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
|
||||
boostersAndMetrics, sparkJobThread)
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
producedBooster
|
||||
// we should delete the checkpoint directory after a successful training
|
||||
xgbExecParams.checkpointParam.foreach {
|
||||
cpParam =>
|
||||
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
|
||||
val checkpointManager = new ExternalCheckpointManager(
|
||||
cpParam.checkpointPath,
|
||||
FileSystem.get(sc.hadoopConfiguration))
|
||||
checkpointManager.cleanPath()
|
||||
}
|
||||
}
|
||||
(booster, metrics)
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
// if the job was aborted due to an exception
|
||||
|
||||
@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
|
||||
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
|
||||
|
||||
def needDeterministicRepartitioning: Boolean = {
|
||||
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
||||
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
||||
import org.scalatest.FunSuite
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
|
||||
import org.scalatest.{FunSuite, Ignore}
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
|
||||
private lazy val (model4, model8) = {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
|
||||
Map[String, Any] = {
|
||||
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
|
||||
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
|
||||
}
|
||||
|
||||
private def createNewModels():
|
||||
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val (model4, model8) = {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = produceParamMap(tmpPath, 2)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
}
|
||||
(tmpPath, model4, model8)
|
||||
}
|
||||
|
||||
test("test update/load models") {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
|
||||
manager.updateCheckpoint(model4._booster.booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
||||
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
|
||||
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
||||
}
|
||||
|
||||
test("test cleanUpHigherVersions") {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.cleanUpHigherVersions(round = 8)
|
||||
manager.cleanUpHigherVersions(8)
|
||||
assert(new File(s"$tmpPath/8.model").exists())
|
||||
|
||||
manager.cleanUpHigherVersions(round = 4)
|
||||
manager.cleanUpHigherVersions(4)
|
||||
assert(!new File(s"$tmpPath/8.model").exists())
|
||||
}
|
||||
|
||||
test("test checkpoint rounds") {
|
||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
|
||||
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
|
||||
import scala.collection.JavaConverters._
|
||||
val (tmpPath, model4, model8) = createNewModels()
|
||||
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||
assertResult(Seq(7))(
|
||||
manager.getCheckpointRounds(0, 7).asScala)
|
||||
assertResult(Seq(2, 4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
|
||||
assertResult(Seq(4, 6, 7))(
|
||||
manager.getCheckpointRounds(2, 7).asScala)
|
||||
}
|
||||
|
||||
|
||||
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
||||
|
||||
val paramMap = produceParamMap(tmpPath, 2)
|
||||
|
||||
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
|
||||
val skipCleanCheckpointMap =
|
||||
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
|
||||
skipCleanCheckpointMap
|
||||
|
||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
def error(model: Booster): Float = eval.eval(
|
||||
model.predict(testDM, outPutMargin = true), testDM)
|
||||
val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
|
||||
|
||||
val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
|
||||
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
|
||||
|
||||
if (skipCleanCheckpoint) {
|
||||
// Check only one model is kept after training
|
||||
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
// Train next model based on prev model
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) > error(prevModel._booster))
|
||||
assert(error(tmpModel) >= error(prevModel._booster))
|
||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
} else {
|
||||
@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
||||
" stop the application") {
|
||||
val spark = ss
|
||||
import spark.implicits._
|
||||
ss.sparkContext.setLogLevel("INFO")
|
||||
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
||||
// vector,
|
||||
val testDF = Seq(
|
||||
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
||||
"does not stop application") {
|
||||
val spark = ss
|
||||
import spark.implicits._
|
||||
ss.sparkContext.setLogLevel("INFO")
|
||||
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
||||
// vector,
|
||||
val testDF = Seq(
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
|
||||
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
|
||||
|
||||
@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.scalatest.{FunSuite, Ignore}
|
||||
|
||||
class RabitRobustnessSuite extends FunSuite with PerTest {
|
||||
|
||||
|
||||
@@ -6,13 +6,25 @@
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j_2.12</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<version>1.0.0</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-hdfs</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
public class ExternalCheckpointManager {
|
||||
|
||||
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||
private String modelSuffix = ".model";
|
||||
private Path checkpointPath;
|
||||
private FileSystem fs;
|
||||
|
||||
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||
if (checkpointPath == null || checkpointPath.isEmpty()) {
|
||||
throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
|
||||
" empty checkpoint path");
|
||||
}
|
||||
this.checkpointPath = new Path(checkpointPath);
|
||||
this.fs = fs;
|
||||
}
|
||||
|
||||
private String getPath(int version) {
|
||||
return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
|
||||
}
|
||||
|
||||
private List<Integer> getExistingVersions() throws IOException {
|
||||
if (!fs.exists(checkpointPath)) {
|
||||
return new ArrayList<>();
|
||||
} else {
|
||||
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||
.map(path -> path.getPath().getName())
|
||||
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||
.map(fileName -> Integer.valueOf(
|
||||
fileName.substring(0, fileName.length() - modelSuffix.length())))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
public void cleanPath() throws IOException {
|
||||
fs.delete(checkpointPath, true);
|
||||
}
|
||||
|
||||
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||
List<Integer> versions = getExistingVersions();
|
||||
if (versions.size() > 0) {
|
||||
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
||||
String checkpointPath = getPath(latestVersion);
|
||||
InputStream in = fs.open(new Path(checkpointPath));
|
||||
logger.info("loaded checkpoint from " + checkpointPath);
|
||||
Booster booster = XGBoost.loadModel(in);
|
||||
booster.setVersion(latestVersion);
|
||||
return booster;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||
List<String> prevModelPaths = getExistingVersions().stream()
|
||||
.map(this::getPath).collect(Collectors.toList());
|
||||
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
||||
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||
boosterToCheckpoint.saveModel(out);
|
||||
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
||||
prevModelPaths.stream().forEach(path -> {
|
||||
try {
|
||||
fs.delete(new Path(path), true);
|
||||
} catch (IOException e) {
|
||||
logger.error("failed to delete outdated checkpoint at " + path, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
||||
try {
|
||||
fs.delete(new Path(getPath(v)), true);
|
||||
} catch (IOException e) {
|
||||
logger.error("failed to clean checkpoint from other training instance", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
||||
throws IOException {
|
||||
if (checkpointInterval > 0) {
|
||||
List<Integer> prevRounds =
|
||||
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
||||
prevRounds.add(0);
|
||||
int firstCheckpointRound = prevRounds.stream()
|
||||
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
||||
List<Integer> arr = new ArrayList<>();
|
||||
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
||||
arr.add(i);
|
||||
}
|
||||
arr.add(numOfRounds);
|
||||
return arr;
|
||||
} else if (checkpointInterval <= 0) {
|
||||
List<Integer> l = new ArrayList<Integer>();
|
||||
l.add(numOfRounds);
|
||||
return l;
|
||||
} else {
|
||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,16 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.*;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
/**
|
||||
* trainer for xgboost
|
||||
@@ -108,35 +112,34 @@ public class XGBoost {
|
||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* goes to the unexpected direction in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster) throws XGBoostError {
|
||||
private static void saveCheckpoint(
|
||||
Booster booster,
|
||||
int iter,
|
||||
Set<Integer> checkpointIterations,
|
||||
ExternalCheckpointManager ecm) throws XGBoostError {
|
||||
try {
|
||||
if (checkpointIterations.contains(iter)) {
|
||||
ecm.updateCheckpoint(booster);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
|
||||
throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
|
||||
}
|
||||
}
|
||||
|
||||
public static Booster trainAndSaveCheckpoint(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int numRounds,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster,
|
||||
int checkpointInterval,
|
||||
String checkpointPath,
|
||||
FileSystem fs) throws XGBoostError, IOException {
|
||||
// collect eval matrices
|
||||
String[] evalNames;
|
||||
DMatrix[] evalMats;
|
||||
@@ -144,6 +147,11 @@ public class XGBoost {
|
||||
int bestIteration;
|
||||
List<String> names = new ArrayList<String>();
|
||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||
Set<Integer> checkpointIterations = new HashSet<>();
|
||||
ExternalCheckpointManager ecm = null;
|
||||
if (checkpointPath != null) {
|
||||
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||
}
|
||||
|
||||
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
|
||||
names.add(evalEntry.getKey());
|
||||
@@ -158,7 +166,7 @@ public class XGBoost {
|
||||
bestScore = Float.MAX_VALUE;
|
||||
}
|
||||
bestIteration = 0;
|
||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
||||
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||
|
||||
//collect all data matrixs
|
||||
DMatrix[] allMats;
|
||||
@@ -181,14 +189,19 @@ public class XGBoost {
|
||||
booster.setParams(params);
|
||||
}
|
||||
|
||||
//begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
|
||||
if (ecm != null) {
|
||||
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||
}
|
||||
|
||||
// begin to train
|
||||
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||
booster.saveRabitCheckpoint();
|
||||
}
|
||||
|
||||
@@ -224,7 +237,7 @@ public class XGBoost {
|
||||
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
|
||||
Rabit.trackerPrint(String.format(
|
||||
"early stopping after %d rounds away from the best iteration",
|
||||
earlyStoppingRounds));
|
||||
earlyStoppingRounds));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -239,6 +252,44 @@ public class XGBoost {
|
||||
return booster;
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* goes to the unexpected direction in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRounds,
|
||||
Booster booster) throws XGBoostError {
|
||||
try {
|
||||
return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
|
||||
earlyStoppingRounds, booster,
|
||||
-1, null, null);
|
||||
} catch (IOException e) {
|
||||
logger.error("training failed in xgboost4j", e);
|
||||
throw new XGBoostError("training failed in xgboost4j ", e);
|
||||
}
|
||||
}
|
||||
|
||||
private static Integer tryGetIntFromObject(Object o) {
|
||||
if (o instanceof Integer) {
|
||||
return (int)o;
|
||||
|
||||
@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
|
||||
public XGBoostError(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
/**
 * Creates an error carrying both a detail message and the underlying cause.
 *
 * @param message description of the failure
 * @param cause   the exception that triggered this error; passed to the superclass constructor
 */
public XGBoostError(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
|
||||
/**
 * Scala-side wrapper around the Java checkpoint manager that accepts and returns Scala
 * `Booster` instances.
 *
 * @param checkpointPath checkpoint location, forwarded to the Java manager
 * @param fs             file system handle, forwarded to the Java manager
 */
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
  extends JavaECM(checkpointPath, fs) {

  /** Stores a checkpoint by handing the Scala booster's underlying Java booster to the parent. */
  def updateCheckpoint(booster: Booster): Unit =
    super.updateCheckpoint(booster.booster)

  /**
   * Loads the checkpoint through the parent and wraps it in a Scala `Booster`.
   *
   * @return the wrapped booster, or `null` when the parent returns `null`
   */
  def loadCheckpointAsScalaBooster(): Booster =
    Option(super.loadCheckpointAsBooster()).map(new Booster(_)).orNull
}
|
||||
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.InputStream
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
/**
|
||||
* XGBoost Scala Training function.
|
||||
*/
|
||||
object XGBoost {
|
||||
|
||||
/**
 * Trains through the Java API, with external checkpointing when `checkpointParams` is defined
 * and plain training otherwise.
 *
 * @param dtrain             training data
 * @param params             training parameters; entries with a `null` value are dropped and
 *                           the remaining values are passed on as strings
 * @param numRounds          number of boosting rounds
 * @param watches            named matrices evaluated during training
 * @param metrics            output array for evaluation metrics, may be `null`
 * @param obj                customized objective, may be `null`
 * @param eval               customized evaluation, may be `null`
 * @param earlyStoppingRound early-stopping setting forwarded to the Java API
 * @param prevBooster        booster to continue from, or `null` to start fresh
 * @param checkpointParams   checkpoint settings; `None` disables external checkpoints
 * @return `prevBooster` itself when it is non-null, otherwise a new `Booster` wrapping the
 *         Java result
 */
private[scala] def trainAndSaveCheckpoint(
    dtrain: DMatrix,
    params: Map[String, Any],
    numRounds: Int,
    watches: Map[String, DMatrix] = Map(),
    metrics: Array[Array[Float]] = null,
    obj: ObjectiveTrait = null,
    eval: EvalTrait = null,
    earlyStoppingRound: Int = 0,
    prevBooster: Booster,
    checkpointParams: Option[ExternalCheckpointParams]): Booster = {
  val javaWatches = watches.mapValues(_.jDMatrix).asJava
  val javaPrevBooster = if (prevBooster != null) prevBooster.booster else null

  // we have to filter null value for customized obj and eval
  def javaParams =
    params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava

  val trainedInJava = checkpointParams match {
    case Some(cp) =>
      JXGBoost.trainAndSaveCheckpoint(
        dtrain.jDMatrix,
        javaParams,
        numRounds, javaWatches, metrics, obj, eval, earlyStoppingRound, javaPrevBooster,
        cp.checkpointInterval,
        cp.checkpointPath,
        new Path(cp.checkpointPath).getFileSystem(new Configuration()))
    case None =>
      JXGBoost.train(
        dtrain.jDMatrix,
        javaParams,
        numRounds, javaWatches, metrics, obj, eval, earlyStoppingRound, javaPrevBooster)
  }

  // Avoid creating a new SBooster with the same JBooster
  if (prevBooster != null) prevBooster else new Booster(trainedInJava)
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
@@ -55,23 +101,8 @@ object XGBoost {
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0,
|
||||
booster: Booster = null): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (booster == null) {
|
||||
null
|
||||
} else {
|
||||
booster.booster
|
||||
}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
if (booster == null) {
|
||||
new Booster(xgboostInJava)
|
||||
} else {
|
||||
// Avoid creating a new SBooster with the same JBooster
|
||||
booster
|
||||
}
|
||||
trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
|
||||
booster, None)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -126,3 +157,41 @@ object XGBoost {
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Settings for external checkpointing.
 *
 * @param checkpointInterval  value of the "checkpoint_interval" parameter
 * @param checkpointPath      value of the "checkpoint_path" parameter
 * @param skipCleanCheckpoint value of the "skip_clean_checkpoint" parameter
 */
private[scala] case class ExternalCheckpointParams(
    checkpointInterval: Int,
    checkpointPath: String,
    skipCleanCheckpoint: Boolean)

private[scala] object ExternalCheckpointParams {

  /**
   * Reads the checkpoint settings out of a parameter map.
   *
   * @param params user parameters; the keys "checkpoint_path", "checkpoint_interval" and
   *               "skip_clean_checkpoint" are consulted
   * @return `None` when the path is absent, `null` or empty, or when the interval is 0;
   *         otherwise the extracted settings
   * @throws IllegalArgumentException when one of the three values has an unexpected type
   */
  def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
    val path: String = params.getOrElse("checkpoint_path", null) match {
      case null | "" => null
      case text: String => text
      case other => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
        s" an instance of String, but current value is ${other}")
    }

    val interval: Int = params.getOrElse("checkpoint_interval", 0) match {
      case freq: Int => freq
      case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
        " an instance of Int.")
    }

    val skipClean: Boolean = params.getOrElse("skip_clean_checkpoint", false) match {
      case flag: Boolean => flag
      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
        " an instance of Boolean")
    }

    // Checkpointing is only enabled with both a usable path and a non-zero interval.
    Option(path)
      .filter(_ => interval != 0)
      .map(ExternalCheckpointParams(interval, _, skipClean))
  }
}
|
||||
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ setup(name='xgboost',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7'],
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8'],
|
||||
python_requires='>=3.5',
|
||||
url='https://github.com/dmlc/xgboost')
|
||||
|
||||
@@ -79,6 +79,7 @@ setup(name='xgboost',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7'],
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8'],
|
||||
python_requires='>=3.5',
|
||||
url='https://github.com/dmlc/xgboost')
|
||||
|
||||
@@ -1 +1 @@
|
||||
1.0.0-SNAPSHOT
|
||||
1.0.2
|
||||
|
||||
@@ -5,6 +5,8 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from .core import DMatrix, Booster
|
||||
from .training import train, cv
|
||||
@@ -19,6 +21,12 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if sys.version_info[:2] == (3, 5):
|
||||
warnings.warn(
|
||||
'Python 3.5 support is deprecated; XGBoost will require Python 3.6+ in the near future. ' +
|
||||
'Consider upgrading to Python 3.6+.',
|
||||
FutureWarning)
|
||||
|
||||
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
||||
with open(VERSION_FILE) as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
@@ -79,6 +79,14 @@ else:
|
||||
# END NUMPY PATHLIB ATTRIBUTION
|
||||
###############################################################################
|
||||
|
||||
|
||||
def lazy_isinstance(instance, module, name):
    '''Identify a type by its module and class name instead of the class object.

    Returns True only when the exact type of ``instance`` has ``__module__`` equal
    to ``module`` and ``__name__`` equal to ``name``; subclasses do not match.
    '''
    instance_type = type(instance)
    return instance_type.__module__ == module and instance_type.__name__ == name
|
||||
|
||||
|
||||
# pandas
|
||||
try:
|
||||
from pandas import DataFrame, Series
|
||||
@@ -95,27 +103,6 @@ except ImportError:
|
||||
pandas_concat = None
|
||||
PANDAS_INSTALLED = False
|
||||
|
||||
# dt
|
||||
try:
|
||||
# Workaround for #4473, compatibility with dask
|
||||
if sys.__stdin__ is not None and sys.__stdin__.closed:
|
||||
sys.__stdin__ = None
|
||||
import datatable
|
||||
|
||||
if hasattr(datatable, "Frame"):
|
||||
DataTable = datatable.Frame
|
||||
else:
|
||||
DataTable = datatable.DataTable
|
||||
DT_INSTALLED = True
|
||||
except ImportError:
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class DataTable(object):
|
||||
""" dummy for datatable.DataTable """
|
||||
|
||||
DT_INSTALLED = False
|
||||
|
||||
|
||||
# cudf
|
||||
try:
|
||||
from cudf import DataFrame as CUDF_DataFrame
|
||||
|
||||
@@ -19,9 +19,9 @@ import scipy.sparse
|
||||
|
||||
from .compat import (
|
||||
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
|
||||
PANDAS_INSTALLED, DataTable,
|
||||
CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
||||
os_fspath, os_PathLike)
|
||||
PANDAS_INSTALLED, CUDF_INSTALLED,
|
||||
CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
||||
os_fspath, os_PathLike, lazy_isinstance)
|
||||
from .libpath import find_lib_path
|
||||
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
@@ -319,7 +319,8 @@ DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
||||
def _maybe_dt_data(data, feature_names, feature_types,
|
||||
meta=None, meta_type=None):
|
||||
"""Validate feature names and types if data table"""
|
||||
if not isinstance(data, DataTable):
|
||||
if (not lazy_isinstance(data, 'datatable', 'Frame') and
|
||||
not lazy_isinstance(data, 'datatable', 'DataTable')):
|
||||
return data, feature_names, feature_types
|
||||
|
||||
if meta and data.shape[1] > 1:
|
||||
@@ -470,7 +471,7 @@ class DMatrix(object):
|
||||
self._init_from_csc(data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
self._init_from_npy2d(data, missing, nthread)
|
||||
elif isinstance(data, DataTable):
|
||||
elif lazy_isinstance(data, 'datatable', 'Frame'):
|
||||
self._init_from_dt(data, nthread)
|
||||
elif hasattr(data, "__cuda_array_interface__"):
|
||||
self._init_from_array_interface(data, missing, nthread)
|
||||
@@ -1052,7 +1053,7 @@ class Booster(object):
|
||||
_check_call(
|
||||
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
|
||||
self.__dict__.update(state)
|
||||
elif isinstance(model_file, (STRING_TYPES, os_PathLike)):
|
||||
elif isinstance(model_file, (STRING_TYPES, os_PathLike, bytearray)):
|
||||
self.load_model(model_file)
|
||||
elif model_file is None:
|
||||
pass
|
||||
@@ -1512,7 +1513,8 @@ class Booster(object):
|
||||
return ctypes2buffer(cptr, length.value)
|
||||
|
||||
def load_model(self, fname):
|
||||
"""Load the model from a file, local or as URI.
|
||||
"""Load the model from a file or bytearray. Path to file can be local
|
||||
or as an URI.
|
||||
|
||||
The model is loaded from an XGBoost format which is universal among the
|
||||
various XGBoost interfaces. Auxiliary attributes of the Python Booster
|
||||
@@ -1530,6 +1532,12 @@ class Booster(object):
|
||||
# from URL.
|
||||
_check_call(_LIB.XGBoosterLoadModel(
|
||||
self.handle, c_str(os_fspath(fname))))
|
||||
elif isinstance(fname, bytearray):
|
||||
buf = fname
|
||||
length = c_bst_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr,
|
||||
length))
|
||||
else:
|
||||
raise TypeError('Unknown file type: ', fname)
|
||||
|
||||
|
||||
@@ -600,6 +600,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
|
||||
results = train(self.client, params, dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals)
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self._Booster = results['booster']
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
|
||||
@@ -200,7 +200,7 @@ Parameters
|
||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||
['estimators', 'model', 'objective'])
|
||||
class XGBModel(XGBModelBase):
|
||||
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring
|
||||
# pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
|
||||
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
|
||||
verbosity=None, objective=None, booster=None,
|
||||
tree_method=None, n_jobs=None, gamma=None,
|
||||
@@ -210,7 +210,8 @@ class XGBModel(XGBModelBase):
|
||||
scale_pos_weight=None, base_score=None, random_state=None,
|
||||
missing=None, num_parallel_tree=None,
|
||||
monotone_constraints=None, interaction_constraints=None,
|
||||
importance_type="gain", gpu_id=None, **kwargs):
|
||||
importance_type="gain", gpu_id=None,
|
||||
validate_parameters=False, **kwargs):
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise XGBoostError(
|
||||
'sklearn needs to be installed in order to use this module')
|
||||
@@ -243,6 +244,10 @@ class XGBModel(XGBModelBase):
|
||||
self.interaction_constraints = interaction_constraints
|
||||
self.importance_type = importance_type
|
||||
self.gpu_id = gpu_id
|
||||
# Parameter validation is not working with Scikit-Learn interface, as
|
||||
# it passes all parameters into XGBoost core, whether they are used or
|
||||
# not.
|
||||
self.validate_parameters = validate_parameters
|
||||
|
||||
def __setstate__(self, state):
|
||||
# backward compatibility code
|
||||
@@ -314,11 +319,35 @@ class XGBModel(XGBModelBase):
|
||||
if isinstance(params['random_state'], np.random.RandomState):
|
||||
params['random_state'] = params['random_state'].randint(
|
||||
np.iinfo(np.int32).max)
|
||||
# Parameter validation is not working with Scikit-Learn interface, as
|
||||
# it passes all parameters into XGBoost core, whether they are used or
|
||||
# not.
|
||||
if 'validate_parameters' not in params.keys():
|
||||
params['validate_parameters'] = False
|
||||
|
||||
def parse_parameter(value):
    # Convert to a number: int first, then float; None when neither conversion
    # accepts the value.
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        return None
|
||||
|
||||
# Get internal parameter values
|
||||
try:
|
||||
config = json.loads(self.get_booster().save_config())
|
||||
stack = [config]
|
||||
internal = {}
|
||||
while stack:
|
||||
obj = stack.pop()
|
||||
for k, v in obj.items():
|
||||
if k.endswith('_param'):
|
||||
for p_k, p_v in v.items():
|
||||
internal[p_k] = p_v
|
||||
elif isinstance(v, dict):
|
||||
stack.append(v)
|
||||
|
||||
for k, v in internal.items():
|
||||
if k in params.keys() and params[k] is None:
|
||||
params[k] = parse_parameter(v)
|
||||
except XGBoostError:
|
||||
pass
|
||||
return params
|
||||
|
||||
def get_xgb_params(self):
|
||||
@@ -405,8 +434,8 @@ class XGBModel(XGBModelBase):
|
||||
self.classes_ = np.array(v)
|
||||
continue
|
||||
if k == 'type' and type(self).__name__ != v:
|
||||
msg = f'Current model type: {type(self).__name__}, ' + \
|
||||
f'type of model in file: {v}'
|
||||
msg = 'Current model type: {}, '.format(type(self).__name__) + \
|
||||
'type of model in file: {}'.format(v)
|
||||
raise TypeError(msg)
|
||||
if k == 'type':
|
||||
continue
|
||||
|
||||
@@ -38,7 +38,7 @@ def _train_internal(params, dtrain,
|
||||
|
||||
_params = dict(params) if isinstance(params, list) else params
|
||||
|
||||
if 'num_parallel_tree' in _params and params[
|
||||
if 'num_parallel_tree' in _params and _params[
|
||||
'num_parallel_tree'] is not None:
|
||||
num_parallel_tree = _params['num_parallel_tree']
|
||||
nboost //= num_parallel_tree
|
||||
|
||||
@@ -663,7 +663,11 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
|
||||
* \brief fill a histogram by zeroes
|
||||
*/
|
||||
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
  // Strict R mode: assign a default-constructed GradStats to every entry in [begin, end).
  auto const first = hist.begin() + begin;
  auto const last = hist.begin() + end;
  std::fill(first, last, tree::GradStats());
#else  // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
  // Otherwise: zero the raw bytes of the entries in [begin, end).
  size_t const n_bytes = (end - begin) * sizeof(tree::GradStats);
  memset(hist.data() + begin, '\0', n_bytes);
#endif  // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
}
|
||||
|
||||
/*!
|
||||
|
||||
@@ -117,14 +117,16 @@ std::string LoadSequentialFile(std::string fname) {
|
||||
size_t f_size_bytes = fs.st_size;
|
||||
buffer.resize(f_size_bytes + 1);
|
||||
int32_t fd = open(fname.c_str(), O_RDONLY);
|
||||
#if defined(__linux__)
|
||||
posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
#endif // defined(__linux__)
|
||||
ssize_t bytes_read = read(fd, &buffer[0], f_size_bytes);
|
||||
if (bytes_read < 0) {
|
||||
close(fd);
|
||||
ReadErr();
|
||||
}
|
||||
close(fd);
|
||||
#else
|
||||
#else // defined(__unix__)
|
||||
FILE *f = fopen(fname.c_str(), "r");
|
||||
if (f == NULL) {
|
||||
std::string msg;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
* Copyright (c) by Contributors 2019
|
||||
*/
|
||||
#include <cctype>
|
||||
#include <locale>
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
@@ -24,7 +25,7 @@ void JsonWriter::Visit(JsonArray const* arr) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto const& value = vec[i];
|
||||
this->Save(value);
|
||||
if (i != size-1) { Write(", "); }
|
||||
if (i != size-1) { Write(","); }
|
||||
}
|
||||
this->Write("]");
|
||||
}
|
||||
@@ -38,7 +39,7 @@ void JsonWriter::Visit(JsonObject const* obj) {
|
||||
size_t size = obj->getObject().size();
|
||||
|
||||
for (auto& value : obj->getObject()) {
|
||||
this->Write("\"" + value.first + "\": ");
|
||||
this->Write("\"" + value.first + "\":");
|
||||
this->Save(value.second);
|
||||
|
||||
if (i != size-1) {
|
||||
@@ -692,47 +693,23 @@ Json JsonReader::ParseBoolean() {
|
||||
return Json{JsonBoolean{result}};
|
||||
}
|
||||
|
||||
// This is an ad-hoc solution for writing numeric value in standard way. We need to add
|
||||
// a locale independent way of writing stream like `std::{from, to}_chars' from C++-17.
|
||||
// FIXME(trivialfis): Remove this.
|
||||
class GlobalCLocale {
|
||||
std::locale ori_;
|
||||
|
||||
public:
|
||||
GlobalCLocale() : ori_{std::locale()} {
|
||||
std::string const name {"C"};
|
||||
try {
|
||||
std::locale::global(std::locale(name.c_str()));
|
||||
} catch (std::runtime_error const& e) {
|
||||
LOG(FATAL) << "Failed to set locale: " << name;
|
||||
}
|
||||
}
|
||||
~GlobalCLocale() {
|
||||
std::locale::global(ori_);
|
||||
}
|
||||
};
|
||||
|
||||
Json Json::Load(StringView str) {
|
||||
GlobalCLocale guard;
|
||||
JsonReader reader(str);
|
||||
Json json{reader.Load()};
|
||||
return json;
|
||||
}
|
||||
|
||||
Json Json::Load(JsonReader* reader) {
|
||||
GlobalCLocale guard;
|
||||
Json json{reader->Load()};
|
||||
return json;
|
||||
}
|
||||
|
||||
void Json::Dump(Json json, std::ostream *stream, bool pretty) {
|
||||
GlobalCLocale guard;
|
||||
JsonWriter writer(stream, pretty);
|
||||
writer.Save(json);
|
||||
}
|
||||
|
||||
void Json::Dump(Json json, std::string* str, bool pretty) {
|
||||
GlobalCLocale guard;
|
||||
std::stringstream ss;
|
||||
JsonWriter writer(&ss, pretty);
|
||||
writer.Save(json);
|
||||
|
||||
@@ -15,12 +15,23 @@
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
|
||||
#define OBSERVER_PRINT LOG(INFO)
|
||||
#define OBSERVER_ENDL ""
|
||||
#define OBSERVER_NEWLINE ""
|
||||
#else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
|
||||
#define OBSERVER_PRINT std::cout
|
||||
#define OBSERVER_ENDL std::endl
|
||||
#define OBSERVER_NEWLINE "\n"
|
||||
#endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
|
||||
|
||||
namespace xgboost {
|
||||
/*\brief An observer for logging internal data structures.
|
||||
*
|
||||
* This class is designed to be `diff` tool friendly, which means it uses plain
|
||||
* `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or
|
||||
* similiar facilities.
|
||||
* similar facilities. Exception: use `LOG(INFO)` for the R package, to comply
|
||||
* with CRAN policy.
|
||||
*/
|
||||
class TrainingObserver {
|
||||
#if defined(XGBOOST_USE_DEBUG_OUTPUT)
|
||||
@@ -32,17 +43,17 @@ class TrainingObserver {
|
||||
public:
|
||||
void Update(int32_t iter) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
std::cout << "Iter: " << iter << std::endl;
|
||||
OBSERVER_PRINT << "Iter: " << iter << OBSERVER_ENDL;
|
||||
}
|
||||
/*\brief Observe tree. */
|
||||
void Observe(RegTree const& tree) {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
std::cout << "Tree:" << std::endl;
|
||||
OBSERVER_PRINT << "Tree:" << OBSERVER_ENDL;
|
||||
Json j_tree {Object()};
|
||||
tree.SaveModel(&j_tree);
|
||||
std::string str;
|
||||
Json::Dump(j_tree, &str, true);
|
||||
std::cout << str << std::endl;
|
||||
OBSERVER_PRINT << str << OBSERVER_ENDL;
|
||||
}
|
||||
/*\brief Observe tree. */
|
||||
void Observe(RegTree const* p_tree) {
|
||||
@@ -54,15 +65,15 @@ class TrainingObserver {
|
||||
template <typename T>
|
||||
void Observe(std::vector<T> const& h_vec, std::string name) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
std::cout << "Procedure: " << name << std::endl;
|
||||
OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
|
||||
|
||||
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||
std::cout << h_vec[i] << ", ";
|
||||
OBSERVER_PRINT << h_vec[i] << ", ";
|
||||
if (i % 8 == 0) {
|
||||
std::cout << '\n';
|
||||
OBSERVER_PRINT << OBSERVER_NEWLINE;
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
OBSERVER_PRINT << OBSERVER_ENDL;
|
||||
}
|
||||
/*\brief Observe data hosted by `HostDeviceVector'. */
|
||||
template <typename T>
|
||||
@@ -85,16 +96,16 @@ class TrainingObserver {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
|
||||
Json obj {toJson(p)};
|
||||
std::cout << "Parameter: " << name << ":\n" << obj << std::endl;
|
||||
OBSERVER_PRINT << "Parameter: " << name << ":\n" << obj << OBSERVER_ENDL;
|
||||
}
|
||||
/*\brief Observe parameters provided by users. */
|
||||
void Observe(Args const& args) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
|
||||
for (auto kv : args) {
|
||||
std::cout << kv.first << ": " << kv.second << "\n";
|
||||
OBSERVER_PRINT << kv.first << ": " << kv.second << OBSERVER_NEWLINE;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
OBSERVER_PRINT << OBSERVER_ENDL;
|
||||
}
|
||||
|
||||
/*\brief Get a global instance. */
|
||||
|
||||
@@ -89,10 +89,10 @@ void Monitor::PrintStatistics(StatMap const& statistics) const {
|
||||
"Timer for " << kv.first << " did not get stopped properly.";
|
||||
continue;
|
||||
}
|
||||
std::cout << kv.first << ": " << static_cast<double>(kv.second.second) / 1e+6
|
||||
<< "s, " << kv.second.first << " calls @ "
|
||||
<< kv.second.second
|
||||
<< "us" << std::endl;
|
||||
LOG(CONSOLE) << kv.first << ": " << static_cast<double>(kv.second.second) / 1e+6
|
||||
<< "s, " << kv.second.first << " calls @ "
|
||||
<< kv.second.second
|
||||
<< "us" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,10 +107,9 @@ void Monitor::Print() const {
|
||||
if (rabit::GetRank() == 0) {
|
||||
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||
for (size_t i = 0; i < world.size(); ++i) {
|
||||
std::cout << "From rank: " << i << ": " << std::endl;
|
||||
LOG(CONSOLE) << "From rank: " << i << ": " << std::endl;
|
||||
auto const& statistic = world[i];
|
||||
this->PrintStatistics(statistic);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -123,7 +122,6 @@ void Monitor::Print() const {
|
||||
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||
this->PrintStatistics(stat_map);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
|
||||
201
src/learner.cc
201
src/learner.cc
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file learner.cc
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
@@ -67,19 +67,26 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
||||
/* \brief global bias */
|
||||
bst_float base_score;
|
||||
/* \brief number of features */
|
||||
unsigned num_feature;
|
||||
uint32_t num_feature;
|
||||
/* \brief number of classes, if it is multi-class classification */
|
||||
int num_class;
|
||||
int32_t num_class;
|
||||
/*! \brief Model contain additional properties */
|
||||
int contain_extra_attrs;
|
||||
int32_t contain_extra_attrs;
|
||||
/*! \brief Model contain eval metrics */
|
||||
int contain_eval_metrics;
|
||||
int32_t contain_eval_metrics;
|
||||
/*! \brief the version of XGBoost. */
|
||||
uint32_t major_version;
|
||||
uint32_t minor_version;
|
||||
/*! \brief reserved field */
|
||||
int reserved[29];
|
||||
int reserved[27];
|
||||
/*! \brief constructor */
|
||||
LearnerModelParamLegacy() {
|
||||
std::memset(this, 0, sizeof(LearnerModelParamLegacy));
|
||||
base_score = 0.5f;
|
||||
major_version = std::get<0>(Version::Self());
|
||||
minor_version = std::get<1>(Version::Self());
|
||||
static_assert(sizeof(LearnerModelParamLegacy) == 136,
|
||||
"Do not change the size of this struct, as it will break binary IO.");
|
||||
}
|
||||
// Skip other legacy fields.
|
||||
Json ToJson() const {
|
||||
@@ -117,8 +124,9 @@ LearnerModelParam::LearnerModelParam(
|
||||
LearnerModelParamLegacy const &user_param, float base_margin)
|
||||
: base_score{base_margin}, num_feature{user_param.num_feature},
|
||||
num_output_group{user_param.num_class == 0
|
||||
? 1
|
||||
: static_cast<uint32_t>(user_param.num_class)} {}
|
||||
? 1
|
||||
: static_cast<uint32_t>(user_param.num_class)}
|
||||
{}
|
||||
|
||||
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
// data split mode, can be row, col, or none.
|
||||
@@ -140,7 +148,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
.describe("Data split mode for distributed training.");
|
||||
DMLC_DECLARE_FIELD(disable_default_eval_metric)
|
||||
.set_default(0)
|
||||
.describe("flag to disable default metric. Set to >0 to disable");
|
||||
.describe("Flag to disable default metric. Set to >0 to disable");
|
||||
DMLC_DECLARE_FIELD(booster)
|
||||
.set_default("gbtree")
|
||||
.describe("Gradient booster used for training.");
|
||||
@@ -200,6 +208,7 @@ class LearnerImpl : public Learner {
|
||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
||||
|
||||
tparam_.UpdateAllowUnknown(args);
|
||||
auto mparam_backup = mparam_;
|
||||
mparam_.UpdateAllowUnknown(args);
|
||||
generic_parameters_.UpdateAllowUnknown(args);
|
||||
generic_parameters_.CheckDeprecated();
|
||||
@@ -217,17 +226,33 @@ class LearnerImpl : public Learner {
|
||||
|
||||
// set seed only before the model is initialized
|
||||
common::GlobalRandom().seed(generic_parameters_.seed);
|
||||
|
||||
// must precede configure gbm since num_features is required for gbm
|
||||
this->ConfigureNumFeatures();
|
||||
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
||||
this->ConfigureObjective(old_tparam, &args);
|
||||
this->ConfigureGBM(old_tparam, args);
|
||||
this->ConfigureMetrics(args);
|
||||
|
||||
// Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
|
||||
// After 1.0.0 we save the value provided by user and keep it immutable instead. To
|
||||
// keep the stability, we initialize it in binary LoadModel instead of configuration.
|
||||
// Under what condition should we omit the transformation:
|
||||
//
|
||||
// - base_score is loaded from old binary model.
|
||||
//
|
||||
// What are the other possible conditions:
|
||||
//
|
||||
// - model loaded from new binary or JSON.
|
||||
// - model is created from scratch.
|
||||
// - model is configured second time due to change of parameter
|
||||
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
}
|
||||
|
||||
this->ConfigureGBM(old_tparam, args);
|
||||
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
|
||||
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
this->ConfigureMetrics(args);
|
||||
|
||||
this->need_configuration_ = false;
|
||||
if (generic_parameters_.validate_parameters) {
|
||||
@@ -269,10 +294,7 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
}
|
||||
}
|
||||
auto learner_model_param = mparam_.ToJson();
|
||||
for (auto const& kv : get<Object>(learner_model_param)) {
|
||||
keys.emplace_back(kv.first);
|
||||
}
|
||||
|
||||
keys.emplace_back(kEvalMetric);
|
||||
keys.emplace_back("verbosity");
|
||||
keys.emplace_back("num_output_group");
|
||||
@@ -340,9 +362,6 @@ class LearnerImpl : public Learner {
|
||||
cache_));
|
||||
gbm_->LoadModel(gradient_booster);
|
||||
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
|
||||
auto const& j_attributes = get<Object const>(learner.at("attributes"));
|
||||
attributes_.clear();
|
||||
for (auto const& kv : j_attributes) {
|
||||
@@ -425,6 +444,7 @@ class LearnerImpl : public Learner {
|
||||
auto& learner_parameters = out["learner"];
|
||||
|
||||
learner_parameters["learner_train_param"] = toJson(tparam_);
|
||||
learner_parameters["learner_model_param"] = mparam_.ToJson();
|
||||
learner_parameters["gradient_booster"] = Object();
|
||||
auto& gradient_booster = learner_parameters["gradient_booster"];
|
||||
gbm_->SaveConfig(&gradient_booster);
|
||||
@@ -461,6 +481,7 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
if (header[0] == '{') {
|
||||
// Dispatch to JSON
|
||||
auto json_stream = common::FixedSizeStream(&fp);
|
||||
std::string buffer;
|
||||
json_stream.Take(&buffer);
|
||||
@@ -473,25 +494,10 @@ class LearnerImpl : public Learner {
|
||||
// read parameter
|
||||
CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
|
||||
<< "BoostLearner: wrong model format";
|
||||
{
|
||||
// backward compatibility code for compatible with old model type
|
||||
// for new model, Read(&name_obj_) is suffice
|
||||
uint64_t len;
|
||||
CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len));
|
||||
if (len >= std::numeric_limits<unsigned>::max()) {
|
||||
int gap;
|
||||
CHECK_EQ(fi->Read(&gap, sizeof(gap)), sizeof(gap))
|
||||
<< "BoostLearner: wrong model format";
|
||||
len = len >> static_cast<uint64_t>(32UL);
|
||||
}
|
||||
if (len != 0) {
|
||||
tparam_.objective.resize(len);
|
||||
CHECK_EQ(fi->Read(&tparam_.objective[0], len), len)
|
||||
<< "BoostLearner: wrong model format";
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
|
||||
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
|
||||
// duplicated code with LazyInitModel
|
||||
|
||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
|
||||
&learner_model_param_, cache_));
|
||||
@@ -510,34 +516,57 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
|
||||
}
|
||||
if (tparam_.objective == "count:poisson") {
|
||||
std::string max_delta_step;
|
||||
fi->Read(&max_delta_step);
|
||||
cfg_["max_delta_step"] = max_delta_step;
|
||||
bool warn_old_model { false };
|
||||
if (attributes_.find("count_poisson_max_delta_step") != attributes_.cend()) {
|
||||
// Loading model from < 1.0.0, objective is not saved.
|
||||
cfg_["max_delta_step"] = attributes_.at("count_poisson_max_delta_step");
|
||||
attributes_.erase("count_poisson_max_delta_step");
|
||||
warn_old_model = true;
|
||||
} else {
|
||||
warn_old_model = false;
|
||||
}
|
||||
if (mparam_.contain_eval_metrics != 0) {
|
||||
std::vector<std::string> metr;
|
||||
fi->Read(&metr);
|
||||
for (auto name : metr) {
|
||||
metrics_.emplace_back(Metric::Create(name, &generic_parameters_));
|
||||
|
||||
if (mparam_.major_version >= 1) {
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
} else {
|
||||
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
|
||||
// attribute in the saved model.
|
||||
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score);
|
||||
warn_old_model = true;
|
||||
}
|
||||
if (attributes_.find("objective") != attributes_.cend()) {
|
||||
auto obj_str = attributes_.at("objective");
|
||||
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
||||
obj_->LoadConfig(j_obj);
|
||||
attributes_.erase("objective");
|
||||
} else {
|
||||
warn_old_model = true;
|
||||
}
|
||||
if (attributes_.find("metrics") != attributes_.cend()) {
|
||||
auto metrics_str = attributes_.at("metrics");
|
||||
std::vector<std::string> names { common::Split(metrics_str, ';') };
|
||||
attributes_.erase("metrics");
|
||||
for (auto const& n : names) {
|
||||
this->SetParam(kEvalMetric, n);
|
||||
}
|
||||
}
|
||||
|
||||
if (warn_old_model) {
|
||||
LOG(WARNING) << "Loading model from XGBoost < 1.0.0, consider saving it "
|
||||
"again for improved compatibility";
|
||||
}
|
||||
|
||||
// Renew the version.
|
||||
mparam_.major_version = std::get<0>(Version::Self());
|
||||
mparam_.minor_version = std::get<1>(Version::Self());
|
||||
|
||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||
|
||||
auto n = tparam_.__DICT__();
|
||||
cfg_.insert(n.cbegin(), n.cend());
|
||||
|
||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
||||
generic_parameters_.UpdateAllowUnknown(args);
|
||||
gbm_->Configure(args);
|
||||
obj_->Configure({cfg_.begin(), cfg_.end()});
|
||||
|
||||
for (auto& p_metric : metrics_) {
|
||||
p_metric->Configure({cfg_.begin(), cfg_.end()});
|
||||
}
|
||||
|
||||
// copy dsplit from config since it will not run again during restore
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
@@ -554,15 +583,8 @@ class LearnerImpl : public Learner {
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
|
||||
std::vector<std::pair<std::string, std::string> > extra_attr;
|
||||
// extra attributed to be added just before saving
|
||||
if (tparam_.objective == "count:poisson") {
|
||||
auto it = cfg_.find("max_delta_step");
|
||||
if (it != cfg_.end()) {
|
||||
// write `max_delta_step` parameter as extra attribute of booster
|
||||
mparam.contain_extra_attrs = 1;
|
||||
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
|
||||
}
|
||||
}
|
||||
mparam.contain_extra_attrs = 1;
|
||||
|
||||
{
|
||||
std::vector<std::string> saved_params;
|
||||
// check if rabit_bootstrap_cache were set to non zero before adding to checkpoint
|
||||
@@ -579,6 +601,24 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
// Similar to JSON model IO, we save the objective.
|
||||
Json j_obj { Object() };
|
||||
obj_->SaveConfig(&j_obj);
|
||||
std::string obj_doc;
|
||||
Json::Dump(j_obj, &obj_doc);
|
||||
extra_attr.emplace_back("objective", obj_doc);
|
||||
}
|
||||
// As of 1.0.0, JVM Package and R Package uses Save/Load model for serialization.
|
||||
// Remove this part once they are ported to use actual serialization methods.
|
||||
if (mparam.contain_eval_metrics != 0) {
|
||||
std::stringstream os;
|
||||
for (auto& ev : metrics_) {
|
||||
os << ev->Name() << ";";
|
||||
}
|
||||
extra_attr.emplace_back("metrics", os.str());
|
||||
}
|
||||
|
||||
fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
|
||||
fo->Write(tparam_.objective);
|
||||
fo->Write(tparam_.booster);
|
||||
@@ -589,26 +629,7 @@ class LearnerImpl : public Learner {
|
||||
attr[kv.first] = kv.second;
|
||||
}
|
||||
fo->Write(std::vector<std::pair<std::string, std::string>>(
|
||||
attr.begin(), attr.end()));
|
||||
}
|
||||
if (tparam_.objective == "count:poisson") {
|
||||
auto it = cfg_.find("max_delta_step");
|
||||
if (it != cfg_.end()) {
|
||||
fo->Write(it->second);
|
||||
} else {
|
||||
// recover value of max_delta_step from extra attributes
|
||||
auto it2 = attributes_.find("count_poisson_max_delta_step");
|
||||
const std::string max_delta_step
|
||||
= (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue;
|
||||
fo->Write(max_delta_step);
|
||||
}
|
||||
}
|
||||
if (mparam.contain_eval_metrics != 0) {
|
||||
std::vector<std::string> metr;
|
||||
for (auto& ev : metrics_) {
|
||||
metr.emplace_back(ev->Name());
|
||||
}
|
||||
fo->Write(metr);
|
||||
attr.begin(), attr.end()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -663,11 +684,13 @@ class LearnerImpl : public Learner {
|
||||
|
||||
If you are loading a serialized model (like pickle in Python) generated by older
|
||||
XGBoost, please export the model by calling `Booster.save_model` from that version
|
||||
first, then load it back in current version. See:
|
||||
first, then load it back in current version. There's a simple script for helping
|
||||
the process. See:
|
||||
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
||||
|
||||
for more details about differences between saving model and serializing.
|
||||
for reference to the script, and more details about differences between saving model and
|
||||
serializing.
|
||||
|
||||
)doc";
|
||||
int64_t sz {-1};
|
||||
@@ -856,7 +879,8 @@ class LearnerImpl : public Learner {
|
||||
|
||||
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
|
||||
// Once binary IO is gone, NONE of these config is useful.
|
||||
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0") {
|
||||
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" &&
|
||||
tparam_.objective != "multi:softprob") {
|
||||
cfg_["num_output_group"] = cfg_["num_class"];
|
||||
if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
|
||||
tparam_.objective = "multi:softmax";
|
||||
@@ -921,7 +945,6 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
CHECK_NE(mparam_.num_feature, 0)
|
||||
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||
learner_model_param_.num_feature = mparam_.num_feature;
|
||||
// Remove these once binary IO is gone.
|
||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||
|
||||
@@ -3,6 +3,7 @@ ARG CMAKE_VERSION=3.12
|
||||
|
||||
# Environment
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
SHELL ["/bin/bash", "-c"] # Use Bash as shell
|
||||
|
||||
# Install all basic requirements
|
||||
RUN \
|
||||
@@ -19,10 +20,17 @@ ENV PATH=/opt/python/bin:$PATH
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
|
||||
# Install Python packages
|
||||
# Create new Conda environment with Python 3.5
|
||||
RUN conda create -n py35 python=3.5 && \
|
||||
source activate py35 && \
|
||||
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
|
||||
source deactivate
|
||||
|
||||
# Install Python packages in default env
|
||||
RUN \
|
||||
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \
|
||||
breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 jsonschema && \
|
||||
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh \
|
||||
recommonmark guzzle_sphinx_theme mock breathe graphviz \
|
||||
pytest scikit-learn wheel kubernetes urllib3 jsonschema boto3 && \
|
||||
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
|
||||
pip install "dask[complete]"
|
||||
|
||||
|
||||
35
tests/ci_build/deploy_jvm_packages.sh
Executable file
35
tests/ci_build/deploy_jvm_packages.sh
Executable file
@@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 [spark version]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
spark_version=$1
|
||||
|
||||
# Initialize local Maven repository
|
||||
./tests/ci_build/initialize_maven.sh
|
||||
|
||||
rm -rf build/
|
||||
cd jvm-packages
|
||||
|
||||
# Re-build package without Mock Rabit
|
||||
# Deploy to S3 bucket xgboost-maven-repo
|
||||
mvn --no-transfer-progress package deploy -P release-to-s3 -Dspark.version=${spark_version} -DskipTests
|
||||
|
||||
# Compile XGBoost4J with Scala 2.11 too
|
||||
mvn clean
|
||||
# Rename artifactId of all XGBoost4J packages with suffix _2.11
|
||||
sed -i -e 's/<artifactId>xgboost\(.*\)_[0-9\.]\+/<artifactId>xgboost\1_2.11/' $(find . -name pom.xml)
|
||||
# Modify scala.version and scala.binary.version fields
|
||||
sed -i -e 's/<scala\.version>[0-9\.]\+/<scala.version>2.11.12/' $(find . -name pom.xml)
|
||||
sed -i -e 's/<scala\.binary\.version>[0-9\.]\+/<scala.binary.version>2.11/' $(find . -name pom.xml)
|
||||
|
||||
# Re-build and deploy
|
||||
mvn --no-transfer-progress package deploy -P release-to-s3 -Dspark.version=${spark_version} -DskipTests
|
||||
|
||||
set +x
|
||||
set +e
|
||||
@@ -5,31 +5,35 @@ set -x
|
||||
suite=$1
|
||||
|
||||
# Install XGBoost Python package
|
||||
wheel_found=0
|
||||
for file in python-package/dist/*.whl
|
||||
do
|
||||
if [ -e "${file}" ]
|
||||
function install_xgboost {
|
||||
wheel_found=0
|
||||
for file in python-package/dist/*.whl
|
||||
do
|
||||
if [ -e "${file}" ]
|
||||
then
|
||||
pip install --user "${file}"
|
||||
wheel_found=1
|
||||
break # need just one
|
||||
fi
|
||||
done
|
||||
if [ "$wheel_found" -eq 0 ]
|
||||
then
|
||||
pip install --user "${file}"
|
||||
wheel_found=1
|
||||
break # need just one
|
||||
pushd .
|
||||
cd python-package
|
||||
python setup.py install --user
|
||||
popd
|
||||
fi
|
||||
done
|
||||
if [ "$wheel_found" -eq 0 ]
|
||||
then
|
||||
pushd .
|
||||
cd python-package
|
||||
python setup.py install --user
|
||||
popd
|
||||
fi
|
||||
}
|
||||
|
||||
# Run specified test suite
|
||||
case "$suite" in
|
||||
gpu)
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu
|
||||
;;
|
||||
|
||||
mgpu)
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace -m "mgpu" tests/python-gpu
|
||||
cd tests/distributed
|
||||
./runtests-gpu.sh
|
||||
@@ -39,17 +43,25 @@ case "$suite" in
|
||||
|
||||
cudf)
|
||||
source activate cudf_test
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
|
||||
;;
|
||||
|
||||
cpu)
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace tests/python
|
||||
cd tests/distributed
|
||||
./runtests.sh
|
||||
;;
|
||||
|
||||
cpu-py35)
|
||||
source activate py35
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace tests/python
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Usage: $0 {gpu|mgpu|cudf|cpu}"
|
||||
echo "Usage: $0 {gpu|mgpu|cudf|cpu|cpu-py35}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
@@ -54,7 +54,7 @@ TEST(Version, Basic) {
|
||||
|
||||
ptr = 0;
|
||||
v = std::stoi(str, &ptr);
|
||||
ASSERT_EQ(v, XGBOOST_VER_MINOR) << "patch: " << v;;
|
||||
ASSERT_EQ(v, XGBOOST_VER_PATCH) << "patch: " << v;;
|
||||
|
||||
str = str.substr(ptr);
|
||||
ASSERT_EQ(str.size(), 0);
|
||||
|
||||
@@ -180,6 +180,41 @@ TEST(Learner, JsonModelIO) {
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
TEST(Learner, BinaryModelIO) {
|
||||
size_t constexpr kRows = 8;
|
||||
int32_t constexpr kIters = 4;
|
||||
auto pp_dmat = CreateDMatrix(kRows, 10, 0);
|
||||
std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
|
||||
learner->SetParam("eval_metric", "rmsle");
|
||||
learner->Configure();
|
||||
for (int32_t iter = 0; iter < kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
}
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
std::string const fname = tempdir.path + "binary_model_io.bin";
|
||||
{
|
||||
// Make sure the write is complete before loading.
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||
learner->SaveModel(fo.get());
|
||||
}
|
||||
|
||||
learner.reset(Learner::Create({p_dmat}));
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
|
||||
learner->LoadModel(fi.get());
|
||||
learner->Configure();
|
||||
Json config { Object() };
|
||||
learner->SaveConfig(&config);
|
||||
std::string config_str;
|
||||
Json::Dump(config, &config_str);
|
||||
ASSERT_NE(config_str.find("rmsle"), std::string::npos);
|
||||
ASSERT_EQ(config_str.find("WARNING"), std::string::npos);
|
||||
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
// Tests for automatic GPU configuration.
|
||||
TEST(Learner, GPUConfiguration) {
|
||||
|
||||
148
tests/python/generate_models.py
Normal file
148
tests/python/generate_models.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import xgboost
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
kRounds = 2
|
||||
kRows = 1000
|
||||
kCols = 4
|
||||
kForests = 2
|
||||
kMaxDepth = 2
|
||||
kClasses = 3
|
||||
|
||||
X = np.random.randn(kRows, kCols)
|
||||
w = np.random.uniform(size=kRows)
|
||||
|
||||
version = xgboost.__version__
|
||||
|
||||
np.random.seed(1994)
|
||||
target_dir = 'models'
|
||||
|
||||
|
||||
def booster_bin(model):
|
||||
return os.path.join(target_dir,
|
||||
'xgboost-' + version + '.' + model + '.bin')
|
||||
|
||||
|
||||
def booster_json(model):
|
||||
return os.path.join(target_dir,
|
||||
'xgboost-' + version + '.' + model + '.json')
|
||||
|
||||
|
||||
def skl_bin(model):
|
||||
return os.path.join(target_dir,
|
||||
'xgboost_scikit-' + version + '.' + model + '.bin')
|
||||
|
||||
|
||||
def skl_json(model):
|
||||
return os.path.join(target_dir,
|
||||
'xgboost_scikit-' + version + '.' + model + '.json')
|
||||
|
||||
|
||||
def generate_regression_model():
|
||||
print('Regression')
|
||||
y = np.random.randn(kRows)
|
||||
|
||||
data = xgboost.DMatrix(X, label=y, weight=w)
|
||||
booster = xgboost.train({'tree_method': 'hist',
|
||||
'num_parallel_tree': kForests,
|
||||
'max_depth': kMaxDepth},
|
||||
num_boost_round=kRounds, dtrain=data)
|
||||
booster.save_model(booster_bin('reg'))
|
||||
booster.save_model(booster_json('reg'))
|
||||
|
||||
reg = xgboost.XGBRegressor(tree_method='hist',
|
||||
num_parallel_tree=kForests,
|
||||
max_depth=kMaxDepth,
|
||||
n_estimators=kRounds)
|
||||
reg.fit(X, y, w)
|
||||
reg.save_model(skl_bin('reg'))
|
||||
reg.save_model(skl_json('reg'))
|
||||
|
||||
|
||||
def generate_logistic_model():
|
||||
print('Logistic')
|
||||
y = np.random.randint(0, 2, size=kRows)
|
||||
assert y.max() == 1 and y.min() == 0
|
||||
|
||||
data = xgboost.DMatrix(X, label=y, weight=w)
|
||||
booster = xgboost.train({'tree_method': 'hist',
|
||||
'num_parallel_tree': kForests,
|
||||
'max_depth': kMaxDepth,
|
||||
'objective': 'binary:logistic'},
|
||||
num_boost_round=kRounds, dtrain=data)
|
||||
booster.save_model(booster_bin('logit'))
|
||||
booster.save_model(booster_json('logit'))
|
||||
|
||||
reg = xgboost.XGBClassifier(tree_method='hist',
|
||||
num_parallel_tree=kForests,
|
||||
max_depth=kMaxDepth,
|
||||
n_estimators=kRounds)
|
||||
reg.fit(X, y, w)
|
||||
reg.save_model(skl_bin('logit'))
|
||||
reg.save_model(skl_json('logit'))
|
||||
|
||||
|
||||
def generate_classification_model():
|
||||
print('Classification')
|
||||
y = np.random.randint(0, kClasses, size=kRows)
|
||||
data = xgboost.DMatrix(X, label=y, weight=w)
|
||||
booster = xgboost.train({'num_class': kClasses,
|
||||
'tree_method': 'hist',
|
||||
'num_parallel_tree': kForests,
|
||||
'max_depth': kMaxDepth},
|
||||
num_boost_round=kRounds, dtrain=data)
|
||||
booster.save_model(booster_bin('cls'))
|
||||
booster.save_model(booster_json('cls'))
|
||||
|
||||
cls = xgboost.XGBClassifier(tree_method='hist',
|
||||
num_parallel_tree=kForests,
|
||||
max_depth=kMaxDepth,
|
||||
n_estimators=kRounds)
|
||||
cls.fit(X, y, w)
|
||||
cls.save_model(skl_bin('cls'))
|
||||
cls.save_model(skl_json('cls'))
|
||||
|
||||
|
||||
def generate_ranking_model():
|
||||
print('Learning to Rank')
|
||||
y = np.random.randint(5, size=kRows)
|
||||
w = np.random.uniform(size=20)
|
||||
g = np.repeat(50, 20)
|
||||
|
||||
data = xgboost.DMatrix(X, y, weight=w)
|
||||
data.set_group(g)
|
||||
booster = xgboost.train({'objective': 'rank:ndcg',
|
||||
'num_parallel_tree': kForests,
|
||||
'tree_method': 'hist',
|
||||
'max_depth': kMaxDepth},
|
||||
num_boost_round=kRounds,
|
||||
dtrain=data)
|
||||
booster.save_model(booster_bin('ltr'))
|
||||
booster.save_model(booster_json('ltr'))
|
||||
|
||||
ranker = xgboost.sklearn.XGBRanker(n_estimators=kRounds,
|
||||
tree_method='hist',
|
||||
objective='rank:ndcg',
|
||||
max_depth=kMaxDepth,
|
||||
num_parallel_tree=kForests)
|
||||
ranker.fit(X, y, g, sample_weight=w)
|
||||
ranker.save_model(skl_bin('ltr'))
|
||||
ranker.save_model(skl_json('ltr'))
|
||||
|
||||
|
||||
def write_versions():
|
||||
versions = {'numpy': np.__version__,
|
||||
'xgboost': version}
|
||||
with open(os.path.join(target_dir, 'version'), 'w') as fd:
|
||||
fd.write(str(versions))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if not os.path.exists(target_dir):
|
||||
os.mkdir(target_dir)
|
||||
|
||||
generate_regression_model()
|
||||
generate_logistic_model()
|
||||
generate_classification_model()
|
||||
generate_ranking_model()
|
||||
write_versions()
|
||||
@@ -35,11 +35,16 @@ def captured_output():
|
||||
|
||||
|
||||
class TestBasic(unittest.TestCase):
|
||||
def test_compat(self):
|
||||
from xgboost.compat import lazy_isinstance
|
||||
a = np.array([1, 2, 3])
|
||||
assert lazy_isinstance(a, 'numpy', 'ndarray')
|
||||
assert not lazy_isinstance(a, 'numpy', 'dataframe')
|
||||
|
||||
def test_basic(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||
param = {'max_depth': 2, 'eta': 1,
|
||||
'objective': 'binary:logistic'}
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import json
|
||||
import testing as tm
|
||||
import pytest
|
||||
import locale
|
||||
|
||||
dpath = 'demo/data/'
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
@@ -284,25 +285,49 @@ class TestModels(unittest.TestCase):
|
||||
self.assertRaises(ValueError, bst.predict, dm1)
|
||||
bst.predict(dm2) # success
|
||||
|
||||
def test_model_binary_io(self):
|
||||
model_path = 'test_model_binary_io.bin'
|
||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree',
|
||||
'scale_pos_weight': '0.5'}
|
||||
X = np.random.random((10, 3))
|
||||
y = np.random.random((10,))
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
bst = xgb.train(parameters, dtrain, num_boost_round=2)
|
||||
bst.save_model(model_path)
|
||||
bst = xgb.Booster(model_file=model_path)
|
||||
os.remove(model_path)
|
||||
config = json.loads(bst.save_config())
|
||||
assert float(config['learner']['objective'][
|
||||
'reg_loss_param']['scale_pos_weight']) == 0.5
|
||||
|
||||
buf = bst.save_raw()
|
||||
from_raw = xgb.Booster()
|
||||
from_raw.load_model(buf)
|
||||
|
||||
buf_from_raw = from_raw.save_raw()
|
||||
assert buf == buf_from_raw
|
||||
|
||||
def test_model_json_io(self):
|
||||
model_path = './model.json'
|
||||
loc = locale.getpreferredencoding(False)
|
||||
model_path = 'test_model_json_io.json'
|
||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
||||
j_model = json_model(model_path, parameters)
|
||||
assert isinstance(j_model['learner'], dict)
|
||||
|
||||
bst = xgb.Booster(model_file='./model.json')
|
||||
bst = xgb.Booster(model_file=model_path)
|
||||
|
||||
bst.save_model(fname=model_path)
|
||||
with open('./model.json', 'r') as fd:
|
||||
with open(model_path, 'r') as fd:
|
||||
j_model = json.load(fd)
|
||||
assert isinstance(j_model['learner'], dict)
|
||||
|
||||
os.remove(model_path)
|
||||
assert locale.getpreferredencoding(False) == loc
|
||||
|
||||
@pytest.mark.skipif(**tm.no_json_schema())
|
||||
def test_json_schema(self):
|
||||
import jsonschema
|
||||
model_path = './model.json'
|
||||
model_path = 'test_json_schema.json'
|
||||
path = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
doc = os.path.join(path, 'doc', 'model.schema')
|
||||
|
||||
130
tests/python/test_model_compatibility.py
Normal file
130
tests/python/test_model_compatibility.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import xgboost
|
||||
import os
|
||||
import generate_models as gm
|
||||
import json
|
||||
import zipfile
|
||||
import pytest
|
||||
|
||||
|
||||
def run_model_param_check(config):
|
||||
assert config['learner']['learner_model_param']['num_feature'] == str(4)
|
||||
assert config['learner']['learner_train_param']['booster'] == 'gbtree'
|
||||
|
||||
|
||||
def run_booster_check(booster, name):
|
||||
config = json.loads(booster.save_config())
|
||||
run_model_param_check(config)
|
||||
if name.find('cls') != -1:
|
||||
assert (len(booster.get_dump()) == gm.kForests * gm.kRounds *
|
||||
gm.kClasses)
|
||||
assert float(
|
||||
config['learner']['learner_model_param']['base_score']) == 0.5
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'multi:softmax'
|
||||
elif name.find('logit') != -1:
|
||||
assert len(booster.get_dump()) == gm.kForests * gm.kRounds
|
||||
assert config['learner']['learner_model_param']['num_class'] == str(0)
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'binary:logistic'
|
||||
elif name.find('ltr') != -1:
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'rank:ndcg'
|
||||
else:
|
||||
assert name.find('reg') != -1
|
||||
assert len(booster.get_dump()) == gm.kForests * gm.kRounds
|
||||
assert float(
|
||||
config['learner']['learner_model_param']['base_score']) == 0.5
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'reg:squarederror'
|
||||
|
||||
|
||||
def run_scikit_model_check(name, path):
|
||||
if name.find('reg') != -1:
|
||||
reg = xgboost.XGBRegressor()
|
||||
reg.load_model(path)
|
||||
config = json.loads(reg.get_booster().save_config())
|
||||
if name.find('0.90') != -1:
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'reg:linear'
|
||||
else:
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'reg:squarederror'
|
||||
assert (len(reg.get_booster().get_dump()) ==
|
||||
gm.kRounds * gm.kForests)
|
||||
run_model_param_check(config)
|
||||
elif name.find('cls') != -1:
|
||||
cls = xgboost.XGBClassifier()
|
||||
cls.load_model(path)
|
||||
if name.find('0.90') == -1:
|
||||
assert len(cls.classes_) == gm.kClasses
|
||||
assert len(cls._le.classes_) == gm.kClasses
|
||||
assert cls.n_classes_ == gm.kClasses
|
||||
assert (len(cls.get_booster().get_dump()) ==
|
||||
gm.kRounds * gm.kForests * gm.kClasses), path
|
||||
config = json.loads(cls.get_booster().save_config())
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'multi:softprob', path
|
||||
run_model_param_check(config)
|
||||
elif name.find('ltr') != -1:
|
||||
ltr = xgboost.XGBRanker()
|
||||
ltr.load_model(path)
|
||||
assert (len(ltr.get_booster().get_dump()) ==
|
||||
gm.kRounds * gm.kForests)
|
||||
config = json.loads(ltr.get_booster().save_config())
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'rank:ndcg'
|
||||
run_model_param_check(config)
|
||||
elif name.find('logit') != -1:
|
||||
logit = xgboost.XGBClassifier()
|
||||
logit.load_model(path)
|
||||
assert (len(logit.get_booster().get_dump()) ==
|
||||
gm.kRounds * gm.kForests)
|
||||
config = json.loads(logit.get_booster().save_config())
|
||||
assert config['learner']['learner_train_param'][
|
||||
'objective'] == 'binary:logistic'
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.ci
def test_model_compatibility():
    '''Test model compatibility, can only be run on CI as others don't
    have the credentials.

    The test downloads ``xgboost_model_compatibility_test.zip`` from the
    ``xgboost-ci-jenkins-artifacts`` S3 bucket, extracts it into the
    ``models`` directory next to this file, and runs the booster or
    scikit-learn check on every extracted file except ``version``.  It is
    skipped when boto3 cannot be imported or when no AWS credentials are
    available.

    '''
    models_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'models')
    try:
        import boto3
        import botocore
    except ImportError:
        pytest.skip(
            'Skiping compatibility tests as boto3 is not installed.')

    try:
        s3_bucket = boto3.resource('s3').Bucket('xgboost-ci-jenkins-artifacts')
        # The same string is used as the S3 object key and as the local
        # file name the archive is downloaded to.
        zip_path = 'xgboost_model_compatibility_test.zip'
        s3_bucket.download_file(zip_path, zip_path)
    except botocore.exceptions.NoCredentialsError:
        pytest.skip(
            'Skiping compatibility tests as running on non-CI environment.')

    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(models_dir)

    # Collect every extracted file except the 'version' marker.
    models = [
        os.path.join(root, f) for root, _, files in os.walk(models_dir)
        for f in files
        if f != 'version'
    ]
    assert models, 'No model files found under ' + models_dir

    # Dispatch on the file name prefix; a dedicated loop variable keeps the
    # extraction directory from being shadowed.
    for model_path in models:
        name = os.path.basename(model_path)
        if name.startswith('xgboost-'):
            booster = xgboost.Booster(model_file=model_path)
            run_booster_check(booster, name)
        elif name.startswith('xgboost_scikit'):
            run_scikit_model_check(name, model_path)
        else:
            # An unexpected file in the archive is a test failure; report
            # which file triggered it.
            assert False, 'Unrecognized model file: ' + model_path
|
||||
@@ -115,7 +115,6 @@ class TestRanking(unittest.TestCase):
|
||||
# model training parameters
|
||||
cls.params = {'objective': 'rank:pairwise',
|
||||
'booster': 'gbtree',
|
||||
'silent': 0,
|
||||
'eval_metric': ['ndcg']
|
||||
}
|
||||
|
||||
@@ -143,7 +142,7 @@ class TestRanking(unittest.TestCase):
|
||||
Test cross-validation with a group specified
|
||||
"""
|
||||
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
|
||||
early_stopping_rounds=10, nfold=10, as_pandas=False)
|
||||
early_stopping_rounds=10, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
self.assertSetEqual(set(cv.keys()), {'test-ndcg-mean', 'train-ndcg-mean', 'test-ndcg-std', 'train-ndcg-std'},
|
||||
"CV results dict key mismatch")
|
||||
@@ -153,7 +152,8 @@ class TestRanking(unittest.TestCase):
|
||||
Test cross-validation with a group specified
|
||||
"""
|
||||
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
|
||||
early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False)
|
||||
early_stopping_rounds=10, shuffle=False, nfold=10,
|
||||
as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == 4
|
||||
|
||||
|
||||
@@ -10,18 +10,20 @@ if sys.platform.startswith("win"):
|
||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||
|
||||
try:
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
from distributed import LocalCluster, Client
|
||||
import dask.dataframe as dd
|
||||
import dask.array as da
|
||||
from xgboost.dask import DaskDMatrix
|
||||
except ImportError:
|
||||
client = None
|
||||
loop = None
|
||||
cluster_fixture = None
|
||||
pass
|
||||
LocalCluster = None
|
||||
Client = None
|
||||
dd = None
|
||||
da = None
|
||||
DaskDMatrix = None
|
||||
|
||||
kRows = 1000
|
||||
kCols = 10
|
||||
kWorkers = 5
|
||||
|
||||
|
||||
def generate_array():
|
||||
@@ -31,97 +33,106 @@ def generate_array():
|
||||
return X, y
|
||||
|
||||
|
||||
def test_from_dask_dataframe(client):
|
||||
X, y = generate_array()
|
||||
def test_from_dask_dataframe():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
|
||||
X = dd.from_dask_array(X)
|
||||
y = dd.from_dask_array(y)
|
||||
X = dd.from_dask_array(X)
|
||||
y = dd.from_dask_array(y)
|
||||
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
|
||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert prediction.shape[0] == kRows
|
||||
assert prediction.ndim == 1
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# evals_result is not supported in dask interface.
|
||||
xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||
|
||||
prediction = prediction.compute() # force prediction to be computed
|
||||
with pytest.raises(ValueError):
|
||||
# evals_result is not supported in dask interface.
|
||||
xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
|
||||
|
||||
def test_from_dask_array(client):
|
||||
X, y = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
# results is {'booster': Booster, 'history': {...}}
|
||||
result = xgb.dask.train(client, {}, dtrain)
|
||||
def test_from_dask_array():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
# results is {'booster': Booster, 'history': {...}}
|
||||
result = xgb.dask.train(client, {}, dtrain)
|
||||
|
||||
prediction = xgb.dask.predict(client, result, dtrain)
|
||||
assert prediction.shape[0] == kRows
|
||||
prediction = xgb.dask.predict(client, result, dtrain)
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
|
||||
prediction = prediction.compute() # force prediction to be computed
|
||||
assert isinstance(prediction, da.Array)
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
|
||||
|
||||
def test_regressor(client):
|
||||
X, y = generate_array()
|
||||
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
regressor.client = client
|
||||
regressor.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = regressor.predict(X)
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
regressor.client = client
|
||||
regressor.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = regressor.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
history = regressor.evals_result()
|
||||
history = regressor.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
|
||||
assert list(history['validation_0'].keys())[0] == 'rmse'
|
||||
assert len(history['validation_0']['rmse']) == 2
|
||||
assert list(history['validation_0'].keys())[0] == 'rmse'
|
||||
assert len(history['validation_0']['rmse']) == 2
|
||||
|
||||
|
||||
def test_classifier(client):
|
||||
X, y = generate_array()
|
||||
y = (y * 10).astype(np.int32)
|
||||
classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2)
|
||||
classifier.client = client
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = classifier.predict(X)
|
||||
def test_dask_classifier():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
y = (y * 10).astype(np.int32)
|
||||
classifier = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2)
|
||||
classifier.client = client
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = classifier.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
history = classifier.evals_result()
|
||||
history = classifier.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
|
||||
assert list(history.keys())[0] == 'validation_0'
|
||||
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||
assert len(list(history['validation_0'])) == 1
|
||||
assert len(history['validation_0']['merror']) == 2
|
||||
assert list(history.keys())[0] == 'validation_0'
|
||||
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||
assert len(list(history['validation_0'])) == 1
|
||||
assert len(history['validation_0']['merror']) == 2
|
||||
|
||||
assert classifier.n_classes_ == 10
|
||||
assert classifier.n_classes_ == 10
|
||||
|
||||
# Test with dataframe.
|
||||
X_d = dd.from_dask_array(X)
|
||||
y_d = dd.from_dask_array(y)
|
||||
classifier.fit(X_d, y_d)
|
||||
# Test with dataframe.
|
||||
X_d = dd.from_dask_array(X)
|
||||
y_d = dd.from_dask_array(y)
|
||||
classifier.fit(X_d, y_d)
|
||||
|
||||
assert classifier.n_classes_ == 10
|
||||
prediction = classifier.predict(X_d)
|
||||
assert classifier.n_classes_ == 10
|
||||
prediction = classifier.predict(X_d)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
|
||||
def run_empty_dmatrix(client, parameters):
|
||||
@@ -164,11 +175,15 @@ def run_empty_dmatrix(client, parameters):
|
||||
# No test for Exact, as empty DMatrix handling are mostly for distributed
|
||||
# environment and Exact doesn't support it.
|
||||
|
||||
def test_empty_dmatrix_hist(client):
|
||||
parameters = {'tree_method': 'hist'}
|
||||
run_empty_dmatrix(client, parameters)
|
||||
def test_empty_dmatrix_hist():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
parameters = {'tree_method': 'hist'}
|
||||
run_empty_dmatrix(client, parameters)
|
||||
|
||||
|
||||
def test_empty_dmatrix_approx(client):
|
||||
parameters = {'tree_method': 'approx'}
|
||||
run_empty_dmatrix(client, parameters)
|
||||
def test_empty_dmatrix_approx():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
parameters = {'tree_method': 'approx'}
|
||||
run_empty_dmatrix(client, parameters)
|
||||
|
||||
@@ -34,7 +34,8 @@ def test_binary_classification():
|
||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = cls(random_state=42).fit(X[train_index], y[train_index])
|
||||
clf = cls(random_state=42)
|
||||
xgb_model = clf.fit(X[train_index], y[train_index], eval_metric=['auc', 'logloss'])
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
labels = y[test_index]
|
||||
err = sum(1 for i in range(len(preds))
|
||||
@@ -490,6 +491,13 @@ def test_kwargs():
|
||||
assert clf.get_params()['n_estimators'] == 1000
|
||||
|
||||
|
||||
def test_kwargs_error():
    '''Passing ``n_jobs`` both explicitly and through ``**params`` must raise
    ``TypeError``.

    '''
    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
    with pytest.raises(TypeError):
        # ``n_jobs`` is supplied twice (keyword and unpacked dict), so the
        # call itself raises; nothing placed after it in this block could
        # ever run, hence no further assertions here.
        xgb.XGBClassifier(n_jobs=1000, **params)
|
||||
|
||||
|
||||
def test_kwargs_grid_search():
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
from sklearn import datasets
|
||||
@@ -510,13 +518,6 @@ def test_kwargs_grid_search():
|
||||
assert len(means) == len(set(means))
|
||||
|
||||
|
||||
def test_kwargs_error():
    '''Passing ``n_jobs`` both explicitly and through ``**params`` must raise
    ``TypeError``.

    '''
    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
    with pytest.raises(TypeError):
        # ``n_jobs`` is supplied twice (keyword and unpacked dict), so the
        # call itself raises; nothing placed after it in this block could
        # ever run, hence no further assertions here.
        xgb.XGBClassifier(n_jobs=1000, **params)
|
||||
|
||||
|
||||
def test_sklearn_clone():
|
||||
from sklearn.base import clone
|
||||
|
||||
@@ -525,6 +526,17 @@ def test_sklearn_clone():
|
||||
clone(clf)
|
||||
|
||||
|
||||
def test_sklearn_get_default_params():
    '''Check that ``base_score`` is reported as ``None`` by ``get_params``
    before fitting and as a non-``None`` value after fitting.

    '''
    from sklearn.datasets import load_digits
    # Pass ``n_class`` by keyword: newer scikit-learn releases make the
    # arguments of ``load_digits`` keyword-only, and the keyword form is
    # accepted by older releases as well.
    digits_2class = load_digits(n_class=2)
    X = digits_2class['data']
    y = digits_2class['target']
    cls = xgb.XGBClassifier()
    assert cls.get_params()['base_score'] is None
    # Fit on only the first four samples.
    cls.fit(X[:4, ...], y[:4, ...])
    assert cls.get_params()['base_score'] is not None
|
||||
|
||||
|
||||
def test_validation_weights_xgbmodel():
|
||||
from sklearn.datasets import make_hastie_10_2
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding: utf-8
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
from xgboost.compat import CUDF_INSTALLED, DASK_INSTALLED
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ def no_pandas():
|
||||
|
||||
|
||||
def no_dt():
|
||||
return {'condition': not DT_INSTALLED,
|
||||
import importlib.util
|
||||
spec = importlib.util.find_spec('datatable')
|
||||
return {'condition': spec is None,
|
||||
'reason': 'Datatable is not installed.'}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user