Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74e2f652de | ||
|
|
e02fff53f2 | ||
|
|
fcb2efbadd | ||
|
|
f4621f09c7 | ||
|
|
bf1b2cbfa2 | ||
|
|
d90e7b3117 | ||
|
|
088c43d666 | ||
|
|
69fc8a632f | ||
|
|
213f4fa45a | ||
|
|
5ca21f252a | ||
|
|
eeb67c3d52 | ||
|
|
ed37fdb9c9 | ||
|
|
e7e522fb06 | ||
|
|
8e39a675be | ||
|
|
7f542d2198 | ||
|
|
c8d32102fb |
@@ -1,5 +1,5 @@
|
|||||||
cmake_minimum_required(VERSION 3.12)
|
cmake_minimum_required(VERSION 3.12)
|
||||||
project(xgboost LANGUAGES CXX C VERSION 1.0.0)
|
project(xgboost LANGUAGES CXX C VERSION 1.0.1)
|
||||||
include(cmake/Utils.cmake)
|
include(cmake/Utils.cmake)
|
||||||
list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules")
|
list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules")
|
||||||
cmake_policy(SET CMP0022 NEW)
|
cmake_policy(SET CMP0022 NEW)
|
||||||
@@ -49,7 +49,7 @@ option(USE_SANITIZER "Use santizer flags" OFF)
|
|||||||
option(SANITIZER_PATH "Path to sanitizes.")
|
option(SANITIZER_PATH "Path to sanitizes.")
|
||||||
set(ENABLED_SANITIZERS "address" "leak" CACHE STRING
|
set(ENABLED_SANITIZERS "address" "leak" CACHE STRING
|
||||||
"Semicolon separated list of sanitizer names. E.g 'address;leak'. Supported sanitizers are
|
"Semicolon separated list of sanitizer names. E.g 'address;leak'. Supported sanitizers are
|
||||||
address, leak and thread.")
|
address, leak, undefined and thread.")
|
||||||
## Plugins
|
## Plugins
|
||||||
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
|
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
|
||||||
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
||||||
|
|||||||
1
Jenkinsfile
vendored
1
Jenkinsfile
vendored
@@ -273,6 +273,7 @@ def TestPythonCPU() {
|
|||||||
def docker_binary = "docker"
|
def docker_binary = "docker"
|
||||||
sh """
|
sh """
|
||||||
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu
|
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu
|
||||||
|
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu-py35
|
||||||
"""
|
"""
|
||||||
deleteDir()
|
deleteDir()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -139,6 +139,8 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
|||||||
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
|
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
|
||||||
#' prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
|
#' prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
|
||||||
#' or predinteraction flags is TRUE.
|
#' or predinteraction flags is TRUE.
|
||||||
|
#' @param training whether the prediction result is used for training. For dart booster,
|
||||||
|
#' predicting in training mode will perform dropout.
|
||||||
#' @param ... Parameters passed to \code{predict.xgb.Booster}
|
#' @param ... Parameters passed to \code{predict.xgb.Booster}
|
||||||
#'
|
#'
|
||||||
#' @details
|
#' @details
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ It will use all the trees by default (\code{NULL} value).}
|
|||||||
prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
|
prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
|
||||||
or predinteraction flags is TRUE.}
|
or predinteraction flags is TRUE.}
|
||||||
|
|
||||||
|
\item{training}{whether the prediction result is used for training. For dart booster,
|
||||||
|
predicting in training mode will perform dropout.}
|
||||||
|
|
||||||
\item{...}{Parameters passed to \code{predict.xgb.Booster}}
|
\item{...}{Parameters passed to \code{predict.xgb.Booster}}
|
||||||
}
|
}
|
||||||
\value{
|
\value{
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ num_round <- 2
|
|||||||
test_that("custom objective works", {
|
test_that("custom objective works", {
|
||||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||||
expect_equal(class(bst), "xgb.Booster")
|
expect_equal(class(bst), "xgb.Booster")
|
||||||
expect_equal(length(bst$raw), 1100)
|
|
||||||
expect_false(is.null(bst$evaluation_log))
|
expect_false(is.null(bst$evaluation_log))
|
||||||
expect_false(is.null(bst$evaluation_log$eval_error))
|
expect_false(is.null(bst$evaluation_log$eval_error))
|
||||||
expect_lt(bst$evaluation_log[num_round, eval_error], 0.03)
|
expect_lt(bst$evaluation_log[num_round, eval_error], 0.03)
|
||||||
@@ -58,5 +57,4 @@ test_that("custom objective using DMatrix attr works", {
|
|||||||
param$objective = logregobjattr
|
param$objective = logregobjattr
|
||||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||||
expect_equal(class(bst), "xgb.Booster")
|
expect_equal(class(bst), "xgb.Booster")
|
||||||
expect_equal(length(bst$raw), 1100)
|
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
@xgboost_VERSION_MAJOR@.@xgboost_VERSION_MINOR@.@xgboost_VERSION_PATCH@-SNAPSHOT
|
@xgboost_VERSION_MAJOR@.@xgboost_VERSION_MINOR@.@xgboost_VERSION_PATCH@
|
||||||
|
|||||||
@@ -195,12 +195,22 @@
|
|||||||
"properties": {
|
"properties": {
|
||||||
"version": {
|
"version": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"const": [
|
"items": [
|
||||||
1,
|
{
|
||||||
0,
|
"type": "number",
|
||||||
0
|
"const": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0
|
||||||
|
}
|
||||||
],
|
],
|
||||||
"additionalItems": false
|
"minItems": 3,
|
||||||
|
"maxItems": 3
|
||||||
},
|
},
|
||||||
"learner": {
|
"learner": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
79
doc/python/convert_090to100.py
Normal file
79
doc/python/convert_090to100.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
'''This is a simple script that converts a pickled XGBoost
|
||||||
|
Scikit-Learn interface object from 0.90 to a native model. Pickle
|
||||||
|
format is not stable as it's a direct serialization of Python object.
|
||||||
|
We advise against using it when stability is needed.
|
||||||
|
|
||||||
|
'''
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import xgboost
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
def save_label_encoder(le):
|
||||||
|
'''Save the label encoder in XGBClassifier'''
|
||||||
|
meta = dict()
|
||||||
|
for k, v in le.__dict__.items():
|
||||||
|
if isinstance(v, np.ndarray):
|
||||||
|
meta[k] = v.tolist()
|
||||||
|
else:
|
||||||
|
meta[k] = v
|
||||||
|
return meta
|
||||||
|
|
||||||
|
|
||||||
|
def xgboost_skl_90to100(skl_model):
|
||||||
|
'''Extract the model and related metadata in SKL model.'''
|
||||||
|
model = {}
|
||||||
|
with open(skl_model, 'rb') as fd:
|
||||||
|
old = pickle.load(fd)
|
||||||
|
if not isinstance(old, xgboost.XGBModel):
|
||||||
|
raise TypeError(
|
||||||
|
'The script only handes Scikit-Learn interface object')
|
||||||
|
|
||||||
|
# Save Scikit-Learn specific Python attributes into a JSON document.
|
||||||
|
for k, v in old.__dict__.items():
|
||||||
|
if k == '_le':
|
||||||
|
model[k] = save_label_encoder(v)
|
||||||
|
elif k == 'classes_':
|
||||||
|
model[k] = v.tolist()
|
||||||
|
elif k == '_Booster':
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
json.dumps({k: v})
|
||||||
|
model[k] = v
|
||||||
|
except TypeError:
|
||||||
|
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
|
||||||
|
booster = old.get_booster()
|
||||||
|
# Store the JSON serialization as an attribute
|
||||||
|
booster.set_attr(scikit_learn=json.dumps(model))
|
||||||
|
|
||||||
|
# Save it into a native model.
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
path = 'xgboost_native_model_from_' + skl_model + '-' + str(i) + '.bin'
|
||||||
|
if os.path.exists(path):
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
booster.save_model(path)
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
assert xgboost.__version__ != '1.0.0', ('Please use the XGBoost version'
|
||||||
|
' that generates this pickle.')
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=('A simple script to convert pickle generated by'
|
||||||
|
' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
|
||||||
|
parser.add_argument(
|
||||||
|
'--old-pickle',
|
||||||
|
type=str,
|
||||||
|
help='Path to old pickle file of Scikit-Learn interface object. '
|
||||||
|
'Will output a native model converted from this pickle file',
|
||||||
|
required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
xgboost_skl_90to100(args.old_pickle)
|
||||||
@@ -91,7 +91,12 @@ Loading pickled file from different version of XGBoost
|
|||||||
|
|
||||||
As noted, pickled model is neither portable nor stable, but in some cases the pickled
|
As noted, pickled model is neither portable nor stable, but in some cases the pickled
|
||||||
models are valuable. One way to restore it in the future is to load it back with that
|
models are valuable. One way to restore it in the future is to load it back with that
|
||||||
specific version of Python and XGBoost, export the model by calling `save_model`.
|
specific version of Python and XGBoost, export the model by calling `save_model`. To help
|
||||||
|
ease the mitigation, we created a simple script for converting pickled XGBoost 0.90
|
||||||
|
Scikit-Learn interface object to XGBoost 1.0.0 native model. Please note that the script
|
||||||
|
suits simple use cases, and it's advised not to use pickle when stability is needed.
|
||||||
|
It's located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See
|
||||||
|
comments in the script for more details.
|
||||||
|
|
||||||
********************************************************
|
********************************************************
|
||||||
Saving and Loading the internal parameters configuration
|
Saving and Loading the internal parameters configuration
|
||||||
@@ -190,7 +195,9 @@ You can load it back to the model generated by same version of XGBoost by:
|
|||||||
|
|
||||||
bst.load_config(config)
|
bst.load_config(config)
|
||||||
|
|
||||||
This way users can study the internal representation more closely.
|
This way users can study the internal representation more closely. Please note that some
|
||||||
|
JSON generators make use of locale-dependent floating-point serialization methods, which
|
||||||
|
are not supported by XGBoost.
|
||||||
|
|
||||||
************
|
************
|
||||||
Future Plans
|
Future Plans
|
||||||
|
|||||||
@@ -208,6 +208,8 @@ struct LearnerModelParam {
|
|||||||
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
||||||
// this one as an immutable copy.
|
// this one as an immutable copy.
|
||||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
|
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
|
||||||
|
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
|
||||||
|
bool Initialized() const { return num_feature != 0; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@@ -6,6 +6,6 @@
|
|||||||
|
|
||||||
#define XGBOOST_VER_MAJOR 1
|
#define XGBOOST_VER_MAJOR 1
|
||||||
#define XGBOOST_VER_MINOR 0
|
#define XGBOOST_VER_MINOR 0
|
||||||
#define XGBOOST_VER_PATCH 0
|
#define XGBOOST_VER_PATCH 1
|
||||||
|
|
||||||
#endif // XGBOOST_VERSION_CONFIG_H_
|
#endif // XGBOOST_VERSION_CONFIG_H_
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<name>XGBoost JVM Package</name>
|
<name>XGBoost JVM Package</name>
|
||||||
<description>JVM Package for XGBoost</description>
|
<description>JVM Package for XGBoost</description>
|
||||||
@@ -37,6 +37,7 @@
|
|||||||
<spark.version>2.4.3</spark.version>
|
<spark.version>2.4.3</spark.version>
|
||||||
<scala.version>2.12.8</scala.version>
|
<scala.version>2.12.8</scala.version>
|
||||||
<scala.binary.version>2.12</scala.binary.version>
|
<scala.binary.version>2.12</scala.binary.version>
|
||||||
|
<hadoop.version>2.7.3</hadoop.version>
|
||||||
</properties>
|
</properties>
|
||||||
<repositories>
|
<repositories>
|
||||||
<repository>
|
<repository>
|
||||||
|
|||||||
@@ -6,10 +6,10 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j-example_2.12</artifactId>
|
<artifactId>xgboost4j-example_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
@@ -26,7 +26,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.spark</groupId>
|
<groupId>org.apache.spark</groupId>
|
||||||
@@ -37,7 +37,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
|
|||||||
@@ -6,10 +6,10 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j-flink_2.12</artifactId>
|
<artifactId>xgboost4j-flink_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
@@ -26,7 +26,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j-spark_2.12</artifactId>
|
<artifactId>xgboost4j-spark_2.12</artifactId>
|
||||||
<build>
|
<build>
|
||||||
@@ -24,7 +24,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.spark</groupId>
|
<groupId>org.apache.spark</groupId>
|
||||||
|
|||||||
@@ -1,164 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (c) 2014 by Contributors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.Booster
|
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
|
|
||||||
import org.apache.commons.logging.LogFactory
|
|
||||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
|
||||||
import org.apache.spark.SparkContext
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A class which allows user to save checkpoints every a few rounds. If a previous job fails,
|
|
||||||
* the job can restart training from a saved checkpoints instead of from scratch. This class
|
|
||||||
* provides interface and helper methods for the checkpoint functionality.
|
|
||||||
*
|
|
||||||
* NOTE: This checkpoint is different from Rabit checkpoint. Rabit checkpoint is a native-level
|
|
||||||
* checkpoint stored in executor memory. This is a checkpoint which Spark driver store on HDFS
|
|
||||||
* for every a few iterations.
|
|
||||||
*
|
|
||||||
* @param sc the sparkContext object
|
|
||||||
* @param checkpointPath the hdfs path to store checkpoints
|
|
||||||
*/
|
|
||||||
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
|
|
||||||
private val logger = LogFactory.getLog("XGBoostSpark")
|
|
||||||
private val modelSuffix = ".model"
|
|
||||||
|
|
||||||
private def getPath(version: Int) = {
|
|
||||||
s"$checkpointPath/$version$modelSuffix"
|
|
||||||
}
|
|
||||||
|
|
||||||
private def getExistingVersions: Seq[Int] = {
|
|
||||||
val fs = FileSystem.get(sc.hadoopConfiguration)
|
|
||||||
if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
|
|
||||||
Seq()
|
|
||||||
} else {
|
|
||||||
fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
|
|
||||||
case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def cleanPath(): Unit = {
|
|
||||||
if (checkpointPath != "") {
|
|
||||||
FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load existing checkpoint with the highest version as a Booster object
|
|
||||||
*
|
|
||||||
* @return the booster with the highest version, null if no checkpoints available.
|
|
||||||
*/
|
|
||||||
private[spark] def loadCheckpointAsBooster: Booster = {
|
|
||||||
val versions = getExistingVersions
|
|
||||||
if (versions.nonEmpty) {
|
|
||||||
val version = versions.max
|
|
||||||
val fullPath = getPath(version)
|
|
||||||
val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
|
|
||||||
logger.info(s"Start training from previous booster at $fullPath")
|
|
||||||
val booster = SXGBoost.loadModel(inputStream)
|
|
||||||
booster.booster.setVersion(version)
|
|
||||||
booster
|
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up all previous checkpoints and save a new checkpoint
|
|
||||||
*
|
|
||||||
* @param checkpoint the checkpoint to save as an XGBoostModel
|
|
||||||
*/
|
|
||||||
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
|
|
||||||
val fs = FileSystem.get(sc.hadoopConfiguration)
|
|
||||||
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
|
|
||||||
val fullPath = getPath(checkpoint.getVersion)
|
|
||||||
val outputStream = fs.create(new Path(fullPath), true)
|
|
||||||
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
|
|
||||||
checkpoint.saveModel(outputStream)
|
|
||||||
prevModelPaths.foreach(path => fs.delete(path, true))
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up checkpoint boosters with version higher than or equal to the round.
|
|
||||||
*
|
|
||||||
* @param round the number of rounds in the current training job
|
|
||||||
*/
|
|
||||||
private[spark] def cleanUpHigherVersions(round: Int): Unit = {
|
|
||||||
val higherVersions = getExistingVersions.filter(_ / 2 >= round)
|
|
||||||
higherVersions.foreach { version =>
|
|
||||||
val fs = FileSystem.get(sc.hadoopConfiguration)
|
|
||||||
fs.delete(new Path(getPath(version)), true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
|
|
||||||
* and total number of rounds for the training. Concretely, the checkpoint rounds start with
|
|
||||||
* prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
|
|
||||||
* reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
|
|
||||||
* and the method returns Seq(round)
|
|
||||||
*
|
|
||||||
* @param checkpointInterval Period (in iterations) between checkpoints.
|
|
||||||
* @param round the total number of rounds for the training
|
|
||||||
* @return a seq of integers, each represent the index of round to save the checkpoints
|
|
||||||
*/
|
|
||||||
private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
|
|
||||||
if (checkpointPath.nonEmpty && checkpointInterval > 0) {
|
|
||||||
val prevRounds = getExistingVersions.map(_ / 2)
|
|
||||||
val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
|
|
||||||
(firstCheckpointRound until round by checkpointInterval) :+ round
|
|
||||||
} else if (checkpointInterval <= 0) {
|
|
||||||
Seq(round)
|
|
||||||
} else {
|
|
||||||
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
object CheckpointManager {
|
|
||||||
|
|
||||||
case class CheckpointParam(
|
|
||||||
checkpointPath: String,
|
|
||||||
checkpointInterval: Int,
|
|
||||||
skipCleanCheckpoint: Boolean)
|
|
||||||
|
|
||||||
private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
|
|
||||||
val checkpointPath: String = params.get("checkpoint_path") match {
|
|
||||||
case None => ""
|
|
||||||
case Some(path: String) => path
|
|
||||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
|
|
||||||
" an instance of String.")
|
|
||||||
}
|
|
||||||
|
|
||||||
val checkpointInterval: Int = params.get("checkpoint_interval") match {
|
|
||||||
case None => 0
|
|
||||||
case Some(freq: Int) => freq
|
|
||||||
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
|
||||||
" an instance of Int.")
|
|
||||||
}
|
|
||||||
|
|
||||||
val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
|
||||||
case None => false
|
|
||||||
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
|
||||||
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
|
||||||
" an instance of Boolean")
|
|
||||||
}
|
|
||||||
CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
|
|
||||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||||
|
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||||
import org.apache.commons.io.FileUtils
|
import org.apache.commons.io.FileUtils
|
||||||
import org.apache.commons.logging.LogFactory
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.apache.hadoop.fs.FileSystem
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
|
||||||
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see
|
|||||||
|
|
||||||
private[this] case class XGBoostExecutionParams(
|
private[this] case class XGBoostExecutionParams(
|
||||||
numWorkers: Int,
|
numWorkers: Int,
|
||||||
round: Int,
|
numRounds: Int,
|
||||||
useExternalMemory: Boolean,
|
useExternalMemory: Boolean,
|
||||||
obj: ObjectiveTrait,
|
obj: ObjectiveTrait,
|
||||||
eval: EvalTrait,
|
eval: EvalTrait,
|
||||||
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
|
|||||||
allowNonZeroForMissing: Boolean,
|
allowNonZeroForMissing: Boolean,
|
||||||
trackerConf: TrackerConf,
|
trackerConf: TrackerConf,
|
||||||
timeoutRequestWorkers: Long,
|
timeoutRequestWorkers: Long,
|
||||||
checkpointParam: CheckpointParam,
|
checkpointParam: Option[ExternalCheckpointParams],
|
||||||
xgbInputParams: XGBoostExecutionInputParams,
|
xgbInputParams: XGBoostExecutionInputParams,
|
||||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||||
cacheTrainingSet: Boolean) {
|
cacheTrainingSet: Boolean) {
|
||||||
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
.getOrElse("allow_non_zero_for_missing", false)
|
.getOrElse("allow_non_zero_for_missing", false)
|
||||||
.asInstanceOf[Boolean]
|
.asInstanceOf[Boolean]
|
||||||
validateSparkSslConf
|
validateSparkSslConf
|
||||||
|
|
||||||
if (overridedParams.contains("tree_method")) {
|
if (overridedParams.contains("tree_method")) {
|
||||||
require(overridedParams("tree_method") == "hist" ||
|
require(overridedParams("tree_method") == "hist" ||
|
||||||
overridedParams("tree_method") == "approx" ||
|
overridedParams("tree_method") == "approx" ||
|
||||||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
" an instance of Long.")
|
" an instance of Long.")
|
||||||
}
|
}
|
||||||
val checkpointParam =
|
val checkpointParam =
|
||||||
CheckpointManager.extractParams(overridedParams)
|
ExternalCheckpointParams.extractParams(overridedParams)
|
||||||
|
|
||||||
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
|
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
|
||||||
.asInstanceOf[Double]
|
.asInstanceOf[Double]
|
||||||
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
|
|||||||
watches: Watches,
|
watches: Watches,
|
||||||
xgbExecutionParam: XGBoostExecutionParams,
|
xgbExecutionParam: XGBoostExecutionParams,
|
||||||
rabitEnv: java.util.Map[String, String],
|
rabitEnv: java.util.Map[String, String],
|
||||||
round: Int,
|
|
||||||
obj: ObjectiveTrait,
|
obj: ObjectiveTrait,
|
||||||
eval: EvalTrait,
|
eval: EvalTrait,
|
||||||
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
|
||||||
|
|
||||||
// to workaround the empty partitions in training dataset,
|
// to workaround the empty partitions in training dataset,
|
||||||
// this might not be the best efficient implementation, see
|
// this might not be the best efficient implementation, see
|
||||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||||
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
|
|||||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||||
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
|
||||||
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
||||||
|
val numRounds = xgbExecutionParam.numRounds
|
||||||
|
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
|
||||||
try {
|
try {
|
||||||
Rabit.init(rabitEnv)
|
Rabit.init(rabitEnv)
|
||||||
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
|
||||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
|
||||||
val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
|
val externalCheckpointParams = xgbExecutionParam.checkpointParam
|
||||||
|
val booster = if (makeCheckpoint) {
|
||||||
|
SXGBoost.trainAndSaveCheckpoint(
|
||||||
|
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
||||||
|
watches.toMap, metrics, obj, eval,
|
||||||
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
|
||||||
|
} else {
|
||||||
|
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
|
||||||
watches.toMap, metrics, obj, eval,
|
watches.toMap, metrics, obj, eval,
|
||||||
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||||
|
}
|
||||||
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
||||||
} catch {
|
} catch {
|
||||||
case xgbException: XGBoostError =>
|
case xgbException: XGBoostError =>
|
||||||
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
|
|||||||
trainingData: RDD[XGBLabeledPoint],
|
trainingData: RDD[XGBLabeledPoint],
|
||||||
xgbExecutionParams: XGBoostExecutionParams,
|
xgbExecutionParams: XGBoostExecutionParams,
|
||||||
rabitEnv: java.util.Map[String, String],
|
rabitEnv: java.util.Map[String, String],
|
||||||
checkpointRound: Int,
|
|
||||||
prevBooster: Booster,
|
prevBooster: Booster,
|
||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||||
if (evalSetsMap.isEmpty) {
|
if (evalSetsMap.isEmpty) {
|
||||||
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
|
|||||||
processMissingValues(labeledPoints, xgbExecutionParams.missing,
|
processMissingValues(labeledPoints, xgbExecutionParams.missing,
|
||||||
xgbExecutionParams.allowNonZeroForMissing),
|
xgbExecutionParams.allowNonZeroForMissing),
|
||||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
||||||
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
|
||||||
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
|
xgbExecutionParams.eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
} else {
|
} else {
|
||||||
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
|
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
|
||||||
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
|
|||||||
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
|
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
|
||||||
},
|
},
|
||||||
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
getCacheDirName(xgbExecutionParams.useExternalMemory))
|
||||||
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
|
||||||
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
|
xgbExecutionParams.eval, prevBooster)
|
||||||
}.cache()
|
}.cache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
|
|||||||
trainingData: RDD[Array[XGBLabeledPoint]],
|
trainingData: RDD[Array[XGBLabeledPoint]],
|
||||||
xgbExecutionParam: XGBoostExecutionParams,
|
xgbExecutionParam: XGBoostExecutionParams,
|
||||||
rabitEnv: java.util.Map[String, String],
|
rabitEnv: java.util.Map[String, String],
|
||||||
checkpointRound: Int,
|
|
||||||
prevBooster: Booster,
|
prevBooster: Booster,
|
||||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||||
if (evalSetsMap.isEmpty) {
|
if (evalSetsMap.isEmpty) {
|
||||||
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
|
|||||||
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
|
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
|
||||||
xgbExecutionParam.allowNonZeroForMissing),
|
xgbExecutionParam.allowNonZeroForMissing),
|
||||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
||||||
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
|
||||||
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
|
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
|
||||||
}).cache()
|
}).cache()
|
||||||
} else {
|
} else {
|
||||||
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
|
|||||||
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
|
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
|
||||||
},
|
},
|
||||||
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
getCacheDirName(xgbExecutionParam.useExternalMemory))
|
||||||
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
|
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
|
||||||
xgbExecutionParam.obj,
|
xgbExecutionParam.obj,
|
||||||
xgbExecutionParam.eval,
|
xgbExecutionParam.eval,
|
||||||
prevBooster)
|
prevBooster)
|
||||||
@@ -529,33 +534,30 @@ object XGBoost extends Serializable {
|
|||||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
|
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
|
||||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
|
|
||||||
val sc = trainingData.sparkContext
|
val sc = trainingData.sparkContext
|
||||||
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
|
|
||||||
checkpointPath)
|
|
||||||
checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
|
|
||||||
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
|
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
|
||||||
hasGroup, xgbExecParams.numWorkers)
|
hasGroup, xgbExecParams.numWorkers)
|
||||||
var prevBooster = checkpointManager.loadCheckpointAsBooster
|
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
|
||||||
|
val checkpointManager = new ExternalCheckpointManager(
|
||||||
|
checkpointParam.checkpointPath,
|
||||||
|
FileSystem.get(sc.hadoopConfiguration))
|
||||||
|
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
|
||||||
|
checkpointManager.loadCheckpointAsScalaBooster()
|
||||||
|
}.orNull
|
||||||
try {
|
try {
|
||||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||||
val producedBooster = checkpointManager.getCheckpointRounds(
|
|
||||||
xgbExecParams.checkpointParam.checkpointInterval,
|
|
||||||
xgbExecParams.round).map {
|
|
||||||
checkpointRound: Int =>
|
|
||||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||||
try {
|
val (booster, metrics) = try {
|
||||||
val parallelismTracker = new SparkParallelismTracker(sc,
|
val parallelismTracker = new SparkParallelismTracker(sc,
|
||||||
xgbExecParams.timeoutRequestWorkers,
|
xgbExecParams.timeoutRequestWorkers,
|
||||||
xgbExecParams.numWorkers)
|
xgbExecParams.numWorkers)
|
||||||
|
val rabitEnv = tracker.getWorkerEnvs
|
||||||
tracker.getWorkerEnvs().putAll(xgbRabitParams)
|
|
||||||
val boostersAndMetrics = if (hasGroup) {
|
val boostersAndMetrics = if (hasGroup) {
|
||||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams,
|
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
|
||||||
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
|
evalSetsMap)
|
||||||
} else {
|
} else {
|
||||||
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
|
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
|
||||||
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
|
prevBooster, evalSetsMap)
|
||||||
}
|
}
|
||||||
val sparkJobThread = new Thread() {
|
val sparkJobThread = new Thread() {
|
||||||
override def run() {
|
override def run() {
|
||||||
@@ -569,20 +571,21 @@ object XGBoost extends Serializable {
|
|||||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||||
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
|
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
|
||||||
boostersAndMetrics, sparkJobThread)
|
boostersAndMetrics, sparkJobThread)
|
||||||
if (checkpointRound < xgbExecParams.round) {
|
|
||||||
prevBooster = booster
|
|
||||||
checkpointManager.updateCheckpoint(prevBooster)
|
|
||||||
}
|
|
||||||
(booster, metrics)
|
(booster, metrics)
|
||||||
} finally {
|
} finally {
|
||||||
tracker.stop()
|
tracker.stop()
|
||||||
}
|
}
|
||||||
}.last
|
|
||||||
// we should delete the checkpoint directory after a successful training
|
// we should delete the checkpoint directory after a successful training
|
||||||
if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
|
xgbExecParams.checkpointParam.foreach {
|
||||||
|
cpParam =>
|
||||||
|
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
|
||||||
|
val checkpointManager = new ExternalCheckpointManager(
|
||||||
|
cpParam.checkpointPath,
|
||||||
|
FileSystem.get(sc.hadoopConfiguration))
|
||||||
checkpointManager.cleanPath()
|
checkpointManager.cleanPath()
|
||||||
}
|
}
|
||||||
producedBooster
|
}
|
||||||
|
(booster, metrics)
|
||||||
} catch {
|
} catch {
|
||||||
case t: Throwable =>
|
case t: Throwable =>
|
||||||
// if the job was aborted due to an exception
|
// if the job was aborted due to an exception
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
|
|||||||
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
|
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
|
||||||
|
|
||||||
def needDeterministicRepartitioning: Boolean = {
|
def needDeterministicRepartitioning: Boolean = {
|
||||||
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
|
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.{FunSuite, Ignore}
|
||||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||||
|
|
||||||
class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||||
|
|
||||||
private lazy val (model4, model8) = {
|
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
|
||||||
|
Map[String, Any] = {
|
||||||
|
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
|
||||||
|
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def createNewModels():
|
||||||
|
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
|
||||||
|
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
||||||
|
val (model4, model8) = {
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
val paramMap = produceParamMap(tmpPath, 2)
|
||||||
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
|
|
||||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||||
}
|
}
|
||||||
|
(tmpPath, model4, model8)
|
||||||
|
}
|
||||||
|
|
||||||
test("test update/load models") {
|
test("test update/load models") {
|
||||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
val (tmpPath, model4, model8) = createNewModels()
|
||||||
val manager = new CheckpointManager(sc, tmpPath)
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
manager.updateCheckpoint(model4._booster)
|
|
||||||
|
manager.updateCheckpoint(model4._booster.booster)
|
||||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "4.model")
|
assert(files.head.getPath.getName == "4.model")
|
||||||
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
|
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
|
||||||
|
|
||||||
manager.updateCheckpoint(model8._booster)
|
manager.updateCheckpoint(model8._booster)
|
||||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
assert(files.length == 1)
|
assert(files.length == 1)
|
||||||
assert(files.head.getPath.getName == "8.model")
|
assert(files.head.getPath.getName == "8.model")
|
||||||
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
|
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test cleanUpHigherVersions") {
|
test("test cleanUpHigherVersions") {
|
||||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
val (tmpPath, model4, model8) = createNewModels()
|
||||||
val manager = new CheckpointManager(sc, tmpPath)
|
|
||||||
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
manager.updateCheckpoint(model8._booster)
|
manager.updateCheckpoint(model8._booster)
|
||||||
manager.cleanUpHigherVersions(round = 8)
|
manager.cleanUpHigherVersions(8)
|
||||||
assert(new File(s"$tmpPath/8.model").exists())
|
assert(new File(s"$tmpPath/8.model").exists())
|
||||||
|
|
||||||
manager.cleanUpHigherVersions(round = 4)
|
manager.cleanUpHigherVersions(4)
|
||||||
assert(!new File(s"$tmpPath/8.model").exists())
|
assert(!new File(s"$tmpPath/8.model").exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test checkpoint rounds") {
|
test("test checkpoint rounds") {
|
||||||
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
|
import scala.collection.JavaConverters._
|
||||||
val manager = new CheckpointManager(sc, tmpPath)
|
val (tmpPath, model4, model8) = createNewModels()
|
||||||
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
|
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
|
||||||
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
|
assertResult(Seq(7))(
|
||||||
|
manager.getCheckpointRounds(0, 7).asScala)
|
||||||
|
assertResult(Seq(2, 4, 6, 7))(
|
||||||
|
manager.getCheckpointRounds(2, 7).asScala)
|
||||||
manager.updateCheckpoint(model4._booster)
|
manager.updateCheckpoint(model4._booster)
|
||||||
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
|
assertResult(Seq(4, 6, 7))(
|
||||||
|
manager.getCheckpointRounds(2, 7).asScala)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
|||||||
val testDM = new DMatrix(Classification.test.iterator)
|
val testDM = new DMatrix(Classification.test.iterator)
|
||||||
|
|
||||||
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
|
||||||
|
|
||||||
|
val paramMap = produceParamMap(tmpPath, 2)
|
||||||
|
|
||||||
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
|
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
|
||||||
val skipCleanCheckpointMap =
|
val skipCleanCheckpointMap =
|
||||||
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
|
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
|
||||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
|
|
||||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
|
||||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
|
|
||||||
skipCleanCheckpointMap
|
|
||||||
|
|
||||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
|
||||||
def error(model: Booster): Float = eval.eval(
|
|
||||||
model.predict(testDM, outPutMargin = true), testDM)
|
val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
|
||||||
|
|
||||||
|
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
|
||||||
|
|
||||||
if (skipCleanCheckpoint) {
|
if (skipCleanCheckpoint) {
|
||||||
// Check only one model is kept after training
|
// Check only one model is kept after training
|
||||||
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
|
|||||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||||
// Train next model based on prev model
|
// Train next model based on prev model
|
||||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||||
assert(error(tmpModel) > error(prevModel._booster))
|
assert(error(tmpModel) >= error(prevModel._booster))
|
||||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||||
assert(error(nextModel._booster) < 0.1)
|
assert(error(nextModel._booster) < 0.1)
|
||||||
} else {
|
} else {
|
||||||
@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
|||||||
" stop the application") {
|
" stop the application") {
|
||||||
val spark = ss
|
val spark = ss
|
||||||
import spark.implicits._
|
import spark.implicits._
|
||||||
ss.sparkContext.setLogLevel("INFO")
|
|
||||||
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
||||||
// vector,
|
// vector,
|
||||||
val testDF = Seq(
|
val testDF = Seq(
|
||||||
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
|
|||||||
"does not stop application") {
|
"does not stop application") {
|
||||||
val spark = ss
|
val spark = ss
|
||||||
import spark.implicits._
|
import spark.implicits._
|
||||||
ss.sparkContext.setLogLevel("INFO")
|
|
||||||
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
|
||||||
// vector,
|
// vector,
|
||||||
val testDF = Seq(
|
val testDF = Seq(
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
|
||||||
|
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
|
|
||||||
|
|||||||
@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque
|
|||||||
|
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
|
|
||||||
import org.apache.spark.{SparkConf, SparkContext}
|
import org.scalatest.{FunSuite, Ignore}
|
||||||
import org.scalatest.FunSuite
|
|
||||||
|
|
||||||
|
|
||||||
class RabitRobustnessSuite extends FunSuite with PerTest {
|
class RabitRobustnessSuite extends FunSuite with PerTest {
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,25 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>ml.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
</parent>
|
</parent>
|
||||||
<artifactId>xgboost4j_2.12</artifactId>
|
<artifactId>xgboost4j_2.12</artifactId>
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
<version>1.0.0</version>
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.hadoop</groupId>
|
||||||
|
<artifactId>hadoop-hdfs</artifactId>
|
||||||
|
<version>${hadoop.version}</version>
|
||||||
|
<scope>provided</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.hadoop</groupId>
|
||||||
|
<artifactId>hadoop-common</artifactId>
|
||||||
|
<version>${hadoop.version}</version>
|
||||||
|
<scope>provided</scope>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>junit</groupId>
|
<groupId>junit</groupId>
|
||||||
<artifactId>junit</artifactId>
|
<artifactId>junit</artifactId>
|
||||||
|
|||||||
@@ -0,0 +1,117 @@
|
|||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
import org.apache.hadoop.fs.FileSystem;
|
||||||
|
import org.apache.hadoop.fs.Path;
|
||||||
|
|
||||||
|
public class ExternalCheckpointManager {
|
||||||
|
|
||||||
|
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
||||||
|
private String modelSuffix = ".model";
|
||||||
|
private Path checkpointPath;
|
||||||
|
private FileSystem fs;
|
||||||
|
|
||||||
|
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
||||||
|
if (checkpointPath == null || checkpointPath.isEmpty()) {
|
||||||
|
throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
|
||||||
|
" empty checkpoint path");
|
||||||
|
}
|
||||||
|
this.checkpointPath = new Path(checkpointPath);
|
||||||
|
this.fs = fs;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getPath(int version) {
|
||||||
|
return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<Integer> getExistingVersions() throws IOException {
|
||||||
|
if (!fs.exists(checkpointPath)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
} else {
|
||||||
|
return Arrays.stream(fs.listStatus(checkpointPath))
|
||||||
|
.map(path -> path.getPath().getName())
|
||||||
|
.filter(fileName -> fileName.endsWith(modelSuffix))
|
||||||
|
.map(fileName -> Integer.valueOf(
|
||||||
|
fileName.substring(0, fileName.length() - modelSuffix.length())))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void cleanPath() throws IOException {
|
||||||
|
fs.delete(checkpointPath, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
||||||
|
List<Integer> versions = getExistingVersions();
|
||||||
|
if (versions.size() > 0) {
|
||||||
|
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
||||||
|
String checkpointPath = getPath(latestVersion);
|
||||||
|
InputStream in = fs.open(new Path(checkpointPath));
|
||||||
|
logger.info("loaded checkpoint from " + checkpointPath);
|
||||||
|
Booster booster = XGBoost.loadModel(in);
|
||||||
|
booster.setVersion(latestVersion);
|
||||||
|
return booster;
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
||||||
|
List<String> prevModelPaths = getExistingVersions().stream()
|
||||||
|
.map(this::getPath).collect(Collectors.toList());
|
||||||
|
String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
||||||
|
String tempPath = eventualPath + "-" + UUID.randomUUID();
|
||||||
|
try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
||||||
|
boosterToCheckpoint.saveModel(out);
|
||||||
|
fs.rename(new Path(tempPath), new Path(eventualPath));
|
||||||
|
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
||||||
|
prevModelPaths.stream().forEach(path -> {
|
||||||
|
try {
|
||||||
|
fs.delete(new Path(path), true);
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.error("failed to delete outdated checkpoint at " + path, e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void cleanUpHigherVersions(int currentRound) throws IOException {
|
||||||
|
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
||||||
|
try {
|
||||||
|
fs.delete(new Path(getPath(v)), true);
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.error("failed to clean checkpoint from other training instance", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
||||||
|
throws IOException {
|
||||||
|
if (checkpointInterval > 0) {
|
||||||
|
List<Integer> prevRounds =
|
||||||
|
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
||||||
|
prevRounds.add(0);
|
||||||
|
int firstCheckpointRound = prevRounds.stream()
|
||||||
|
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
||||||
|
List<Integer> arr = new ArrayList<>();
|
||||||
|
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
||||||
|
arr.add(i);
|
||||||
|
}
|
||||||
|
arr.add(numOfRounds);
|
||||||
|
return arr;
|
||||||
|
} else if (checkpointInterval <= 0) {
|
||||||
|
List<Integer> l = new ArrayList<Integer>();
|
||||||
|
l.add(numOfRounds);
|
||||||
|
return l;
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,12 +15,16 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
import org.apache.hadoop.fs.FileSystem;
|
||||||
|
import org.apache.hadoop.fs.Path;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* trainer for xgboost
|
* trainer for xgboost
|
||||||
@@ -108,35 +112,34 @@ public class XGBoost {
|
|||||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private static void saveCheckpoint(
|
||||||
* Train a booster given parameters.
|
Booster booster,
|
||||||
*
|
int iter,
|
||||||
* @param dtrain Data to be trained.
|
Set<Integer> checkpointIterations,
|
||||||
* @param params Parameters.
|
ExternalCheckpointManager ecm) throws XGBoostError {
|
||||||
* @param round Number of boosting iterations.
|
try {
|
||||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
if (checkpointIterations.contains(iter)) {
|
||||||
* performance on the validation set.
|
ecm.updateCheckpoint(booster);
|
||||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
}
|
||||||
* iteration
|
} catch (Exception e) {
|
||||||
* @param earlyStoppingRounds if non-zero, training would be stopped
|
logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
|
||||||
* after a specified number of consecutive
|
throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
|
||||||
* goes to the unexpected direction in any evaluation metric.
|
}
|
||||||
* @param obj customized objective
|
}
|
||||||
* @param eval customized evaluation
|
|
||||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
public static Booster trainAndSaveCheckpoint(
|
||||||
* @return The trained booster.
|
|
||||||
*/
|
|
||||||
public static Booster train(
|
|
||||||
DMatrix dtrain,
|
DMatrix dtrain,
|
||||||
Map<String, Object> params,
|
Map<String, Object> params,
|
||||||
int round,
|
int numRounds,
|
||||||
Map<String, DMatrix> watches,
|
Map<String, DMatrix> watches,
|
||||||
float[][] metrics,
|
float[][] metrics,
|
||||||
IObjective obj,
|
IObjective obj,
|
||||||
IEvaluation eval,
|
IEvaluation eval,
|
||||||
int earlyStoppingRounds,
|
int earlyStoppingRounds,
|
||||||
Booster booster) throws XGBoostError {
|
Booster booster,
|
||||||
|
int checkpointInterval,
|
||||||
|
String checkpointPath,
|
||||||
|
FileSystem fs) throws XGBoostError, IOException {
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
DMatrix[] evalMats;
|
DMatrix[] evalMats;
|
||||||
@@ -144,6 +147,11 @@ public class XGBoost {
|
|||||||
int bestIteration;
|
int bestIteration;
|
||||||
List<String> names = new ArrayList<String>();
|
List<String> names = new ArrayList<String>();
|
||||||
List<DMatrix> mats = new ArrayList<DMatrix>();
|
List<DMatrix> mats = new ArrayList<DMatrix>();
|
||||||
|
Set<Integer> checkpointIterations = new HashSet<>();
|
||||||
|
ExternalCheckpointManager ecm = null;
|
||||||
|
if (checkpointPath != null) {
|
||||||
|
ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
||||||
|
}
|
||||||
|
|
||||||
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
|
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
|
||||||
names.add(evalEntry.getKey());
|
names.add(evalEntry.getKey());
|
||||||
@@ -158,7 +166,7 @@ public class XGBoost {
|
|||||||
bestScore = Float.MAX_VALUE;
|
bestScore = Float.MAX_VALUE;
|
||||||
}
|
}
|
||||||
bestIteration = 0;
|
bestIteration = 0;
|
||||||
metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
||||||
|
|
||||||
//collect all data matrixs
|
//collect all data matrixs
|
||||||
DMatrix[] allMats;
|
DMatrix[] allMats;
|
||||||
@@ -181,14 +189,19 @@ public class XGBoost {
|
|||||||
booster.setParams(params);
|
booster.setParams(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
//begin to train
|
if (ecm != null) {
|
||||||
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
|
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
||||||
|
}
|
||||||
|
|
||||||
|
// begin to train
|
||||||
|
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
||||||
if (booster.getVersion() % 2 == 0) {
|
if (booster.getVersion() % 2 == 0) {
|
||||||
if (obj != null) {
|
if (obj != null) {
|
||||||
booster.update(dtrain, obj);
|
booster.update(dtrain, obj);
|
||||||
} else {
|
} else {
|
||||||
booster.update(dtrain, iter);
|
booster.update(dtrain, iter);
|
||||||
}
|
}
|
||||||
|
saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
||||||
booster.saveRabitCheckpoint();
|
booster.saveRabitCheckpoint();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +252,44 @@ public class XGBoost {
|
|||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train a booster given parameters.
|
||||||
|
*
|
||||||
|
* @param dtrain Data to be trained.
|
||||||
|
* @param params Parameters.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||||
|
* performance on the validation set.
|
||||||
|
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||||
|
* iteration
|
||||||
|
* @param earlyStoppingRounds if non-zero, training would be stopped
|
||||||
|
* after a specified number of consecutive
|
||||||
|
* goes to the unexpected direction in any evaluation metric.
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||||
|
* @return The trained booster.
|
||||||
|
*/
|
||||||
|
public static Booster train(
|
||||||
|
DMatrix dtrain,
|
||||||
|
Map<String, Object> params,
|
||||||
|
int round,
|
||||||
|
Map<String, DMatrix> watches,
|
||||||
|
float[][] metrics,
|
||||||
|
IObjective obj,
|
||||||
|
IEvaluation eval,
|
||||||
|
int earlyStoppingRounds,
|
||||||
|
Booster booster) throws XGBoostError {
|
||||||
|
try {
|
||||||
|
return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
|
||||||
|
earlyStoppingRounds, booster,
|
||||||
|
-1, null, null);
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.error("training failed in xgboost4j", e);
|
||||||
|
throw new XGBoostError("training failed in xgboost4j ", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static Integer tryGetIntFromObject(Object o) {
|
private static Integer tryGetIntFromObject(Object o) {
|
||||||
if (o instanceof Integer) {
|
if (o instanceof Integer) {
|
||||||
return (int)o;
|
return (int)o;
|
||||||
|
|||||||
@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
|
|||||||
public XGBoostError(String message) {
|
public XGBoostError(String message) {
|
||||||
super(message);
|
super(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public XGBoostError(String message, Throwable cause) {
|
||||||
|
super(message, cause);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
|
||||||
|
import org.apache.hadoop.fs.FileSystem
|
||||||
|
|
||||||
|
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
|
||||||
|
extends JavaECM(checkpointPath, fs) {
|
||||||
|
|
||||||
|
def updateCheckpoint(booster: Booster): Unit = {
|
||||||
|
super.updateCheckpoint(booster.booster)
|
||||||
|
}
|
||||||
|
|
||||||
|
def loadCheckpointAsScalaBooster(): Booster = {
|
||||||
|
val loadedBooster = super.loadCheckpointAsBooster()
|
||||||
|
if (loadedBooster == null) {
|
||||||
|
null
|
||||||
|
} else {
|
||||||
|
new Booster(loadedBooster)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
|
|||||||
|
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
import org.apache.hadoop.conf.Configuration
|
||||||
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* XGBoost Scala Training function.
|
* XGBoost Scala Training function.
|
||||||
*/
|
*/
|
||||||
object XGBoost {
|
object XGBoost {
|
||||||
|
|
||||||
|
private[scala] def trainAndSaveCheckpoint(
|
||||||
|
dtrain: DMatrix,
|
||||||
|
params: Map[String, Any],
|
||||||
|
numRounds: Int,
|
||||||
|
watches: Map[String, DMatrix] = Map(),
|
||||||
|
metrics: Array[Array[Float]] = null,
|
||||||
|
obj: ObjectiveTrait = null,
|
||||||
|
eval: EvalTrait = null,
|
||||||
|
earlyStoppingRound: Int = 0,
|
||||||
|
prevBooster: Booster,
|
||||||
|
checkpointParams: Option[ExternalCheckpointParams]): Booster = {
|
||||||
|
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||||
|
val jBooster = if (prevBooster == null) {
|
||||||
|
null
|
||||||
|
} else {
|
||||||
|
prevBooster.booster
|
||||||
|
}
|
||||||
|
val xgboostInJava = checkpointParams.
|
||||||
|
map(cp => {
|
||||||
|
JXGBoost.trainAndSaveCheckpoint(
|
||||||
|
dtrain.jDMatrix,
|
||||||
|
// we have to filter null value for customized obj and eval
|
||||||
|
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||||
|
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
|
||||||
|
cp.checkpointInterval,
|
||||||
|
cp.checkpointPath,
|
||||||
|
new Path(cp.checkpointPath).getFileSystem(new Configuration()))
|
||||||
|
}).
|
||||||
|
getOrElse(
|
||||||
|
JXGBoost.train(
|
||||||
|
dtrain.jDMatrix,
|
||||||
|
// we have to filter null value for customized obj and eval
|
||||||
|
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||||
|
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||||
|
)
|
||||||
|
if (prevBooster == null) {
|
||||||
|
new Booster(xgboostInJava)
|
||||||
|
} else {
|
||||||
|
// Avoid creating a new SBooster with the same JBooster
|
||||||
|
prevBooster
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train a booster given parameters.
|
* Train a booster given parameters.
|
||||||
*
|
*
|
||||||
@@ -55,23 +101,8 @@ object XGBoost {
|
|||||||
eval: EvalTrait = null,
|
eval: EvalTrait = null,
|
||||||
earlyStoppingRound: Int = 0,
|
earlyStoppingRound: Int = 0,
|
||||||
booster: Booster = null): Booster = {
|
booster: Booster = null): Booster = {
|
||||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
|
||||||
val jBooster = if (booster == null) {
|
booster, None)
|
||||||
null
|
|
||||||
} else {
|
|
||||||
booster.booster
|
|
||||||
}
|
|
||||||
val xgboostInJava = JXGBoost.train(
|
|
||||||
dtrain.jDMatrix,
|
|
||||||
// we have to filter null value for customized obj and eval
|
|
||||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
|
||||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
|
||||||
if (booster == null) {
|
|
||||||
new Booster(xgboostInJava)
|
|
||||||
} else {
|
|
||||||
// Avoid creating a new SBooster with the same JBooster
|
|
||||||
booster
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -126,3 +157,41 @@ object XGBoost {
|
|||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private[scala] case class ExternalCheckpointParams(
|
||||||
|
checkpointInterval: Int,
|
||||||
|
checkpointPath: String,
|
||||||
|
skipCleanCheckpoint: Boolean)
|
||||||
|
|
||||||
|
private[scala] object ExternalCheckpointParams {
|
||||||
|
|
||||||
|
def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
|
||||||
|
val checkpointPath: String = params.get("checkpoint_path") match {
|
||||||
|
case None | Some(null) | Some("") => null
|
||||||
|
case Some(path: String) => path
|
||||||
|
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
|
||||||
|
s" an instance of String, but current value is ${params("checkpoint_path")}")
|
||||||
|
}
|
||||||
|
|
||||||
|
val checkpointInterval: Int = params.get("checkpoint_interval") match {
|
||||||
|
case None => 0
|
||||||
|
case Some(freq: Int) => freq
|
||||||
|
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
||||||
|
" an instance of Int.")
|
||||||
|
}
|
||||||
|
|
||||||
|
val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
||||||
|
case None => false
|
||||||
|
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
||||||
|
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
||||||
|
" an instance of Boolean")
|
||||||
|
}
|
||||||
|
if (checkpointPath == null || checkpointInterval == 0) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ setup(name='xgboost',
|
|||||||
'Programming Language :: Python :: 3',
|
'Programming Language :: Python :: 3',
|
||||||
'Programming Language :: Python :: 3.5',
|
'Programming Language :: Python :: 3.5',
|
||||||
'Programming Language :: Python :: 3.6',
|
'Programming Language :: Python :: 3.6',
|
||||||
'Programming Language :: Python :: 3.7'],
|
'Programming Language :: Python :: 3.7',
|
||||||
|
'Programming Language :: Python :: 3.8'],
|
||||||
python_requires='>=3.5',
|
python_requires='>=3.5',
|
||||||
url='https://github.com/dmlc/xgboost')
|
url='https://github.com/dmlc/xgboost')
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ setup(name='xgboost',
|
|||||||
'Programming Language :: Python :: 3',
|
'Programming Language :: Python :: 3',
|
||||||
'Programming Language :: Python :: 3.5',
|
'Programming Language :: Python :: 3.5',
|
||||||
'Programming Language :: Python :: 3.6',
|
'Programming Language :: Python :: 3.6',
|
||||||
'Programming Language :: Python :: 3.7'],
|
'Programming Language :: Python :: 3.7',
|
||||||
|
'Programming Language :: Python :: 3.8'],
|
||||||
python_requires='>=3.5',
|
python_requires='>=3.5',
|
||||||
url='https://github.com/dmlc/xgboost')
|
url='https://github.com/dmlc/xgboost')
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
1.0.0-SNAPSHOT
|
1.0.1
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
from .core import DMatrix, Booster
|
from .core import DMatrix, Booster
|
||||||
from .training import train, cv
|
from .training import train, cv
|
||||||
@@ -19,6 +21,12 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if sys.version_info[:2] == (3, 5):
|
||||||
|
warnings.warn(
|
||||||
|
'Python 3.5 support is deprecated; XGBoost will require Python 3.6+ in the near future. ' +
|
||||||
|
'Consider upgrading to Python 3.6+.',
|
||||||
|
FutureWarning)
|
||||||
|
|
||||||
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
||||||
with open(VERSION_FILE) as f:
|
with open(VERSION_FILE) as f:
|
||||||
__version__ = f.read().strip()
|
__version__ = f.read().strip()
|
||||||
|
|||||||
@@ -600,6 +600,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
|
|||||||
results = train(self.client, params, dtrain,
|
results = train(self.client, params, dtrain,
|
||||||
num_boost_round=self.get_num_boosting_rounds(),
|
num_boost_round=self.get_num_boosting_rounds(),
|
||||||
evals=evals)
|
evals=evals)
|
||||||
|
# pylint: disable=attribute-defined-outside-init
|
||||||
self._Booster = results['booster']
|
self._Booster = results['booster']
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
self.evals_result_ = results['history']
|
self.evals_result_ = results['history']
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ Parameters
|
|||||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||||
['estimators', 'model', 'objective'])
|
['estimators', 'model', 'objective'])
|
||||||
class XGBModel(XGBModelBase):
|
class XGBModel(XGBModelBase):
|
||||||
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring
|
# pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
|
||||||
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
|
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
|
||||||
verbosity=None, objective=None, booster=None,
|
verbosity=None, objective=None, booster=None,
|
||||||
tree_method=None, n_jobs=None, gamma=None,
|
tree_method=None, n_jobs=None, gamma=None,
|
||||||
@@ -210,7 +210,8 @@ class XGBModel(XGBModelBase):
|
|||||||
scale_pos_weight=None, base_score=None, random_state=None,
|
scale_pos_weight=None, base_score=None, random_state=None,
|
||||||
missing=None, num_parallel_tree=None,
|
missing=None, num_parallel_tree=None,
|
||||||
monotone_constraints=None, interaction_constraints=None,
|
monotone_constraints=None, interaction_constraints=None,
|
||||||
importance_type="gain", gpu_id=None, **kwargs):
|
importance_type="gain", gpu_id=None,
|
||||||
|
validate_parameters=False, **kwargs):
|
||||||
if not SKLEARN_INSTALLED:
|
if not SKLEARN_INSTALLED:
|
||||||
raise XGBoostError(
|
raise XGBoostError(
|
||||||
'sklearn needs to be installed in order to use this module')
|
'sklearn needs to be installed in order to use this module')
|
||||||
@@ -243,6 +244,10 @@ class XGBModel(XGBModelBase):
|
|||||||
self.interaction_constraints = interaction_constraints
|
self.interaction_constraints = interaction_constraints
|
||||||
self.importance_type = importance_type
|
self.importance_type = importance_type
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
|
# Parameter validation is not working with Scikit-Learn interface, as
|
||||||
|
# it passes all paraemters into XGBoost core, whether they are used or
|
||||||
|
# not.
|
||||||
|
self.validate_parameters = validate_parameters
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
# backward compatibility code
|
# backward compatibility code
|
||||||
@@ -314,11 +319,35 @@ class XGBModel(XGBModelBase):
|
|||||||
if isinstance(params['random_state'], np.random.RandomState):
|
if isinstance(params['random_state'], np.random.RandomState):
|
||||||
params['random_state'] = params['random_state'].randint(
|
params['random_state'] = params['random_state'].randint(
|
||||||
np.iinfo(np.int32).max)
|
np.iinfo(np.int32).max)
|
||||||
# Parameter validation is not working with Scikit-Learn interface, as
|
|
||||||
# it passes all paraemters into XGBoost core, whether they are used or
|
def parse_parameter(value):
|
||||||
# not.
|
for t in (int, float):
|
||||||
if 'validate_parameters' not in params.keys():
|
try:
|
||||||
params['validate_parameters'] = False
|
ret = t(value)
|
||||||
|
return ret
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get internal parameter values
|
||||||
|
try:
|
||||||
|
config = json.loads(self.get_booster().save_config())
|
||||||
|
stack = [config]
|
||||||
|
internal = {}
|
||||||
|
while stack:
|
||||||
|
obj = stack.pop()
|
||||||
|
for k, v in obj.items():
|
||||||
|
if k.endswith('_param'):
|
||||||
|
for p_k, p_v in v.items():
|
||||||
|
internal[p_k] = p_v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
stack.append(v)
|
||||||
|
|
||||||
|
for k, v in internal.items():
|
||||||
|
if k in params.keys() and params[k] is None:
|
||||||
|
params[k] = parse_parameter(v)
|
||||||
|
except XGBoostError:
|
||||||
|
pass
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
@@ -405,8 +434,8 @@ class XGBModel(XGBModelBase):
|
|||||||
self.classes_ = np.array(v)
|
self.classes_ = np.array(v)
|
||||||
continue
|
continue
|
||||||
if k == 'type' and type(self).__name__ != v:
|
if k == 'type' and type(self).__name__ != v:
|
||||||
msg = f'Current model type: {type(self).__name__}, ' + \
|
msg = 'Current model type: {}, '.format(type(self).__name__) + \
|
||||||
f'type of model in file: {v}'
|
'type of model in file: {}'.format(v)
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
if k == 'type':
|
if k == 'type':
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
* Copyright (c) by Contributors 2019
|
* Copyright (c) by Contributors 2019
|
||||||
*/
|
*/
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
|
#include <locale>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
@@ -24,7 +25,7 @@ void JsonWriter::Visit(JsonArray const* arr) {
|
|||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
auto const& value = vec[i];
|
auto const& value = vec[i];
|
||||||
this->Save(value);
|
this->Save(value);
|
||||||
if (i != size-1) { Write(", "); }
|
if (i != size-1) { Write(","); }
|
||||||
}
|
}
|
||||||
this->Write("]");
|
this->Write("]");
|
||||||
}
|
}
|
||||||
@@ -38,7 +39,7 @@ void JsonWriter::Visit(JsonObject const* obj) {
|
|||||||
size_t size = obj->getObject().size();
|
size_t size = obj->getObject().size();
|
||||||
|
|
||||||
for (auto& value : obj->getObject()) {
|
for (auto& value : obj->getObject()) {
|
||||||
this->Write("\"" + value.first + "\": ");
|
this->Write("\"" + value.first + "\":");
|
||||||
this->Save(value.second);
|
this->Save(value.second);
|
||||||
|
|
||||||
if (i != size-1) {
|
if (i != size-1) {
|
||||||
@@ -692,47 +693,23 @@ Json JsonReader::ParseBoolean() {
|
|||||||
return Json{JsonBoolean{result}};
|
return Json{JsonBoolean{result}};
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is an ad-hoc solution for writing numeric value in standard way. We need to add
|
|
||||||
// a locale independent way of writing stream like `std::{from, to}_chars' from C++-17.
|
|
||||||
// FIXME(trivialfis): Remove this.
|
|
||||||
class GlobalCLocale {
|
|
||||||
std::locale ori_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
GlobalCLocale() : ori_{std::locale()} {
|
|
||||||
std::string const name {"C"};
|
|
||||||
try {
|
|
||||||
std::locale::global(std::locale(name.c_str()));
|
|
||||||
} catch (std::runtime_error const& e) {
|
|
||||||
LOG(FATAL) << "Failed to set locale: " << name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
~GlobalCLocale() {
|
|
||||||
std::locale::global(ori_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Json Json::Load(StringView str) {
|
Json Json::Load(StringView str) {
|
||||||
GlobalCLocale guard;
|
|
||||||
JsonReader reader(str);
|
JsonReader reader(str);
|
||||||
Json json{reader.Load()};
|
Json json{reader.Load()};
|
||||||
return json;
|
return json;
|
||||||
}
|
}
|
||||||
|
|
||||||
Json Json::Load(JsonReader* reader) {
|
Json Json::Load(JsonReader* reader) {
|
||||||
GlobalCLocale guard;
|
|
||||||
Json json{reader->Load()};
|
Json json{reader->Load()};
|
||||||
return json;
|
return json;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Json::Dump(Json json, std::ostream *stream, bool pretty) {
|
void Json::Dump(Json json, std::ostream *stream, bool pretty) {
|
||||||
GlobalCLocale guard;
|
|
||||||
JsonWriter writer(stream, pretty);
|
JsonWriter writer(stream, pretty);
|
||||||
writer.Save(json);
|
writer.Save(json);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Json::Dump(Json json, std::string* str, bool pretty) {
|
void Json::Dump(Json json, std::string* str, bool pretty) {
|
||||||
GlobalCLocale guard;
|
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
JsonWriter writer(&ss, pretty);
|
JsonWriter writer(&ss, pretty);
|
||||||
writer.Save(json);
|
writer.Save(json);
|
||||||
|
|||||||
@@ -15,12 +15,23 @@
|
|||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
|
|
||||||
|
#if defined(XGBOOST_STRICT_R_MODE)
|
||||||
|
#define OBSERVER_PRINT LOG(INFO)
|
||||||
|
#define OBSERVER_ENDL ""
|
||||||
|
#define OBSERVER_NEWLINE ""
|
||||||
|
#else
|
||||||
|
#define OBSERVER_PRINT std::cout
|
||||||
|
#define OBSERVER_ENDL std::endl
|
||||||
|
#define OBSERVER_NEWLINE "\n"
|
||||||
|
#endif // defined(XGBOOST_STRICT_R_MODE)
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/*\brief An observer for logging internal data structures.
|
/*\brief An observer for logging internal data structures.
|
||||||
*
|
*
|
||||||
* This class is designed to be `diff` tool friendly, which means it uses plain
|
* This class is designed to be `diff` tool friendly, which means it uses plain
|
||||||
* `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or
|
* `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or
|
||||||
* similiar facilities.
|
* similiar facilities. Exception: use `LOG(INFO)` for the R package, to comply
|
||||||
|
* with CRAN policy.
|
||||||
*/
|
*/
|
||||||
class TrainingObserver {
|
class TrainingObserver {
|
||||||
#if defined(XGBOOST_USE_DEBUG_OUTPUT)
|
#if defined(XGBOOST_USE_DEBUG_OUTPUT)
|
||||||
@@ -32,17 +43,17 @@ class TrainingObserver {
|
|||||||
public:
|
public:
|
||||||
void Update(int32_t iter) const {
|
void Update(int32_t iter) const {
|
||||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
std::cout << "Iter: " << iter << std::endl;
|
OBSERVER_PRINT << "Iter: " << iter << OBSERVER_ENDL;
|
||||||
}
|
}
|
||||||
/*\brief Observe tree. */
|
/*\brief Observe tree. */
|
||||||
void Observe(RegTree const& tree) {
|
void Observe(RegTree const& tree) {
|
||||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
std::cout << "Tree:" << std::endl;
|
OBSERVER_PRINT << "Tree:" << OBSERVER_ENDL;
|
||||||
Json j_tree {Object()};
|
Json j_tree {Object()};
|
||||||
tree.SaveModel(&j_tree);
|
tree.SaveModel(&j_tree);
|
||||||
std::string str;
|
std::string str;
|
||||||
Json::Dump(j_tree, &str, true);
|
Json::Dump(j_tree, &str, true);
|
||||||
std::cout << str << std::endl;
|
OBSERVER_PRINT << str << OBSERVER_ENDL;
|
||||||
}
|
}
|
||||||
/*\brief Observe tree. */
|
/*\brief Observe tree. */
|
||||||
void Observe(RegTree const* p_tree) {
|
void Observe(RegTree const* p_tree) {
|
||||||
@@ -54,15 +65,15 @@ class TrainingObserver {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
void Observe(std::vector<T> const& h_vec, std::string name) const {
|
void Observe(std::vector<T> const& h_vec, std::string name) const {
|
||||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
std::cout << "Procedure: " << name << std::endl;
|
OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
|
||||||
|
|
||||||
for (size_t i = 0; i < h_vec.size(); ++i) {
|
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||||
std::cout << h_vec[i] << ", ";
|
OBSERVER_PRINT << h_vec[i] << ", ";
|
||||||
if (i % 8 == 0) {
|
if (i % 8 == 0) {
|
||||||
std::cout << '\n';
|
OBSERVER_PRINT << OBSERVER_NEWLINE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
OBSERVER_PRINT << OBSERVER_ENDL;
|
||||||
}
|
}
|
||||||
/*\brief Observe data hosted by `HostDeviceVector'. */
|
/*\brief Observe data hosted by `HostDeviceVector'. */
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -85,16 +96,16 @@ class TrainingObserver {
|
|||||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
|
||||||
Json obj {toJson(p)};
|
Json obj {toJson(p)};
|
||||||
std::cout << "Parameter: " << name << ":\n" << obj << std::endl;
|
OBSERVER_PRINT << "Parameter: " << name << ":\n" << obj << OBSERVER_ENDL;
|
||||||
}
|
}
|
||||||
/*\brief Observe parameters provided by users. */
|
/*\brief Observe parameters provided by users. */
|
||||||
void Observe(Args const& args) const {
|
void Observe(Args const& args) const {
|
||||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
|
||||||
for (auto kv : args) {
|
for (auto kv : args) {
|
||||||
std::cout << kv.first << ": " << kv.second << "\n";
|
OBSERVER_PRINT << kv.first << ": " << kv.second << OBSERVER_NEWLINE;
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
OBSERVER_PRINT << OBSERVER_ENDL;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*\brief Get a global instance. */
|
/*\brief Get a global instance. */
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ void Monitor::PrintStatistics(StatMap const& statistics) const {
|
|||||||
"Timer for " << kv.first << " did not get stopped properly.";
|
"Timer for " << kv.first << " did not get stopped properly.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::cout << kv.first << ": " << static_cast<double>(kv.second.second) / 1e+6
|
LOG(CONSOLE) << kv.first << ": " << static_cast<double>(kv.second.second) / 1e+6
|
||||||
<< "s, " << kv.second.first << " calls @ "
|
<< "s, " << kv.second.first << " calls @ "
|
||||||
<< kv.second.second
|
<< kv.second.second
|
||||||
<< "us" << std::endl;
|
<< "us" << std::endl;
|
||||||
@@ -107,10 +107,9 @@ void Monitor::Print() const {
|
|||||||
if (rabit::GetRank() == 0) {
|
if (rabit::GetRank() == 0) {
|
||||||
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||||
for (size_t i = 0; i < world.size(); ++i) {
|
for (size_t i = 0; i < world.size(); ++i) {
|
||||||
std::cout << "From rank: " << i << ": " << std::endl;
|
LOG(CONSOLE) << "From rank: " << i << ": " << std::endl;
|
||||||
auto const& statistic = world[i];
|
auto const& statistic = world[i];
|
||||||
this->PrintStatistics(statistic);
|
this->PrintStatistics(statistic);
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -123,7 +122,6 @@ void Monitor::Print() const {
|
|||||||
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||||
this->PrintStatistics(stat_map);
|
this->PrintStatistics(stat_map);
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
|
|||||||
195
src/learner.cc
195
src/learner.cc
@@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2014-2019 by Contributors
|
* Copyright 2014-2020 by Contributors
|
||||||
* \file learner.cc
|
* \file learner.cc
|
||||||
* \brief Implementation of learning algorithm.
|
* \brief Implementation of learning algorithm.
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
@@ -67,19 +67,26 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
|||||||
/* \brief global bias */
|
/* \brief global bias */
|
||||||
bst_float base_score;
|
bst_float base_score;
|
||||||
/* \brief number of features */
|
/* \brief number of features */
|
||||||
unsigned num_feature;
|
uint32_t num_feature;
|
||||||
/* \brief number of classes, if it is multi-class classification */
|
/* \brief number of classes, if it is multi-class classification */
|
||||||
int num_class;
|
int32_t num_class;
|
||||||
/*! \brief Model contain additional properties */
|
/*! \brief Model contain additional properties */
|
||||||
int contain_extra_attrs;
|
int32_t contain_extra_attrs;
|
||||||
/*! \brief Model contain eval metrics */
|
/*! \brief Model contain eval metrics */
|
||||||
int contain_eval_metrics;
|
int32_t contain_eval_metrics;
|
||||||
|
/*! \brief the version of XGBoost. */
|
||||||
|
uint32_t major_version;
|
||||||
|
uint32_t minor_version;
|
||||||
/*! \brief reserved field */
|
/*! \brief reserved field */
|
||||||
int reserved[29];
|
int reserved[27];
|
||||||
/*! \brief constructor */
|
/*! \brief constructor */
|
||||||
LearnerModelParamLegacy() {
|
LearnerModelParamLegacy() {
|
||||||
std::memset(this, 0, sizeof(LearnerModelParamLegacy));
|
std::memset(this, 0, sizeof(LearnerModelParamLegacy));
|
||||||
base_score = 0.5f;
|
base_score = 0.5f;
|
||||||
|
major_version = std::get<0>(Version::Self());
|
||||||
|
minor_version = std::get<1>(Version::Self());
|
||||||
|
static_assert(sizeof(LearnerModelParamLegacy) == 136,
|
||||||
|
"Do not change the size of this struct, as it will break binary IO.");
|
||||||
}
|
}
|
||||||
// Skip other legacy fields.
|
// Skip other legacy fields.
|
||||||
Json ToJson() const {
|
Json ToJson() const {
|
||||||
@@ -118,7 +125,8 @@ LearnerModelParam::LearnerModelParam(
|
|||||||
: base_score{base_margin}, num_feature{user_param.num_feature},
|
: base_score{base_margin}, num_feature{user_param.num_feature},
|
||||||
num_output_group{user_param.num_class == 0
|
num_output_group{user_param.num_class == 0
|
||||||
? 1
|
? 1
|
||||||
: static_cast<uint32_t>(user_param.num_class)} {}
|
: static_cast<uint32_t>(user_param.num_class)}
|
||||||
|
{}
|
||||||
|
|
||||||
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||||
// data split mode, can be row, col, or none.
|
// data split mode, can be row, col, or none.
|
||||||
@@ -140,7 +148,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
|||||||
.describe("Data split mode for distributed training.");
|
.describe("Data split mode for distributed training.");
|
||||||
DMLC_DECLARE_FIELD(disable_default_eval_metric)
|
DMLC_DECLARE_FIELD(disable_default_eval_metric)
|
||||||
.set_default(0)
|
.set_default(0)
|
||||||
.describe("flag to disable default metric. Set to >0 to disable");
|
.describe("Flag to disable default metric. Set to >0 to disable");
|
||||||
DMLC_DECLARE_FIELD(booster)
|
DMLC_DECLARE_FIELD(booster)
|
||||||
.set_default("gbtree")
|
.set_default("gbtree")
|
||||||
.describe("Gradient booster used for training.");
|
.describe("Gradient booster used for training.");
|
||||||
@@ -200,6 +208,7 @@ class LearnerImpl : public Learner {
|
|||||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
Args args = {cfg_.cbegin(), cfg_.cend()};
|
||||||
|
|
||||||
tparam_.UpdateAllowUnknown(args);
|
tparam_.UpdateAllowUnknown(args);
|
||||||
|
auto mparam_backup = mparam_;
|
||||||
mparam_.UpdateAllowUnknown(args);
|
mparam_.UpdateAllowUnknown(args);
|
||||||
generic_parameters_.UpdateAllowUnknown(args);
|
generic_parameters_.UpdateAllowUnknown(args);
|
||||||
generic_parameters_.CheckDeprecated();
|
generic_parameters_.CheckDeprecated();
|
||||||
@@ -217,17 +226,33 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
// set seed only before the model is initialized
|
// set seed only before the model is initialized
|
||||||
common::GlobalRandom().seed(generic_parameters_.seed);
|
common::GlobalRandom().seed(generic_parameters_.seed);
|
||||||
|
|
||||||
// must precede configure gbm since num_features is required for gbm
|
// must precede configure gbm since num_features is required for gbm
|
||||||
this->ConfigureNumFeatures();
|
this->ConfigureNumFeatures();
|
||||||
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
||||||
this->ConfigureObjective(old_tparam, &args);
|
this->ConfigureObjective(old_tparam, &args);
|
||||||
this->ConfigureGBM(old_tparam, args);
|
|
||||||
this->ConfigureMetrics(args);
|
|
||||||
|
|
||||||
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
|
|
||||||
|
|
||||||
|
// Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
|
||||||
|
// After 1.0.0 we save the value provided by user and keep it immutable instead. To
|
||||||
|
// keep the stability, we initialize it in binary LoadModel instead of configuration.
|
||||||
|
// Under what condition should we omit the transformation:
|
||||||
|
//
|
||||||
|
// - base_score is loaded from old binary model.
|
||||||
|
//
|
||||||
|
// What are the other possible conditions:
|
||||||
|
//
|
||||||
|
// - model loaded from new binary or JSON.
|
||||||
|
// - model is created from scratch.
|
||||||
|
// - model is configured second time due to change of parameter
|
||||||
|
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
|
||||||
learner_model_param_ = LearnerModelParam(mparam_,
|
learner_model_param_ = LearnerModelParam(mparam_,
|
||||||
obj_->ProbToMargin(mparam_.base_score));
|
obj_->ProbToMargin(mparam_.base_score));
|
||||||
|
}
|
||||||
|
|
||||||
|
this->ConfigureGBM(old_tparam, args);
|
||||||
|
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
|
||||||
|
|
||||||
|
this->ConfigureMetrics(args);
|
||||||
|
|
||||||
this->need_configuration_ = false;
|
this->need_configuration_ = false;
|
||||||
if (generic_parameters_.validate_parameters) {
|
if (generic_parameters_.validate_parameters) {
|
||||||
@@ -269,10 +294,7 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto learner_model_param = mparam_.ToJson();
|
|
||||||
for (auto const& kv : get<Object>(learner_model_param)) {
|
|
||||||
keys.emplace_back(kv.first);
|
|
||||||
}
|
|
||||||
keys.emplace_back(kEvalMetric);
|
keys.emplace_back(kEvalMetric);
|
||||||
keys.emplace_back("verbosity");
|
keys.emplace_back("verbosity");
|
||||||
keys.emplace_back("num_output_group");
|
keys.emplace_back("num_output_group");
|
||||||
@@ -340,9 +362,6 @@ class LearnerImpl : public Learner {
|
|||||||
cache_));
|
cache_));
|
||||||
gbm_->LoadModel(gradient_booster);
|
gbm_->LoadModel(gradient_booster);
|
||||||
|
|
||||||
learner_model_param_ = LearnerModelParam(mparam_,
|
|
||||||
obj_->ProbToMargin(mparam_.base_score));
|
|
||||||
|
|
||||||
auto const& j_attributes = get<Object const>(learner.at("attributes"));
|
auto const& j_attributes = get<Object const>(learner.at("attributes"));
|
||||||
attributes_.clear();
|
attributes_.clear();
|
||||||
for (auto const& kv : j_attributes) {
|
for (auto const& kv : j_attributes) {
|
||||||
@@ -425,6 +444,7 @@ class LearnerImpl : public Learner {
|
|||||||
auto& learner_parameters = out["learner"];
|
auto& learner_parameters = out["learner"];
|
||||||
|
|
||||||
learner_parameters["learner_train_param"] = toJson(tparam_);
|
learner_parameters["learner_train_param"] = toJson(tparam_);
|
||||||
|
learner_parameters["learner_model_param"] = mparam_.ToJson();
|
||||||
learner_parameters["gradient_booster"] = Object();
|
learner_parameters["gradient_booster"] = Object();
|
||||||
auto& gradient_booster = learner_parameters["gradient_booster"];
|
auto& gradient_booster = learner_parameters["gradient_booster"];
|
||||||
gbm_->SaveConfig(&gradient_booster);
|
gbm_->SaveConfig(&gradient_booster);
|
||||||
@@ -461,6 +481,7 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (header[0] == '{') {
|
if (header[0] == '{') {
|
||||||
|
// Dispatch to JSON
|
||||||
auto json_stream = common::FixedSizeStream(&fp);
|
auto json_stream = common::FixedSizeStream(&fp);
|
||||||
std::string buffer;
|
std::string buffer;
|
||||||
json_stream.Take(&buffer);
|
json_stream.Take(&buffer);
|
||||||
@@ -473,25 +494,10 @@ class LearnerImpl : public Learner {
|
|||||||
// read parameter
|
// read parameter
|
||||||
CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
|
CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
|
||||||
<< "BoostLearner: wrong model format";
|
<< "BoostLearner: wrong model format";
|
||||||
{
|
|
||||||
// backward compatibility code for compatible with old model type
|
CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
|
||||||
// for new model, Read(&name_obj_) is suffice
|
|
||||||
uint64_t len;
|
|
||||||
CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len));
|
|
||||||
if (len >= std::numeric_limits<unsigned>::max()) {
|
|
||||||
int gap;
|
|
||||||
CHECK_EQ(fi->Read(&gap, sizeof(gap)), sizeof(gap))
|
|
||||||
<< "BoostLearner: wrong model format";
|
|
||||||
len = len >> static_cast<uint64_t>(32UL);
|
|
||||||
}
|
|
||||||
if (len != 0) {
|
|
||||||
tparam_.objective.resize(len);
|
|
||||||
CHECK_EQ(fi->Read(&tparam_.objective[0], len), len)
|
|
||||||
<< "BoostLearner: wrong model format";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
|
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
|
||||||
// duplicated code with LazyInitModel
|
|
||||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
||||||
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
|
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
|
||||||
&learner_model_param_, cache_));
|
&learner_model_param_, cache_));
|
||||||
@@ -510,34 +516,57 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
|
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
|
||||||
}
|
}
|
||||||
if (tparam_.objective == "count:poisson") {
|
bool warn_old_model { false };
|
||||||
std::string max_delta_step;
|
if (attributes_.find("count_poisson_max_delta_step") != attributes_.cend()) {
|
||||||
fi->Read(&max_delta_step);
|
// Loading model from < 1.0.0, objective is not saved.
|
||||||
cfg_["max_delta_step"] = max_delta_step;
|
cfg_["max_delta_step"] = attributes_.at("count_poisson_max_delta_step");
|
||||||
|
attributes_.erase("count_poisson_max_delta_step");
|
||||||
|
warn_old_model = true;
|
||||||
|
} else {
|
||||||
|
warn_old_model = false;
|
||||||
}
|
}
|
||||||
if (mparam_.contain_eval_metrics != 0) {
|
|
||||||
std::vector<std::string> metr;
|
if (mparam_.major_version >= 1) {
|
||||||
fi->Read(&metr);
|
learner_model_param_ = LearnerModelParam(mparam_,
|
||||||
for (auto name : metr) {
|
obj_->ProbToMargin(mparam_.base_score));
|
||||||
metrics_.emplace_back(Metric::Create(name, &generic_parameters_));
|
} else {
|
||||||
|
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
|
||||||
|
// attribute in the saved model.
|
||||||
|
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score);
|
||||||
|
warn_old_model = true;
|
||||||
|
}
|
||||||
|
if (attributes_.find("objective") != attributes_.cend()) {
|
||||||
|
auto obj_str = attributes_.at("objective");
|
||||||
|
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
||||||
|
obj_->LoadConfig(j_obj);
|
||||||
|
attributes_.erase("objective");
|
||||||
|
} else {
|
||||||
|
warn_old_model = true;
|
||||||
|
}
|
||||||
|
if (attributes_.find("metrics") != attributes_.cend()) {
|
||||||
|
auto metrics_str = attributes_.at("metrics");
|
||||||
|
std::vector<std::string> names { common::Split(metrics_str, ';') };
|
||||||
|
attributes_.erase("metrics");
|
||||||
|
for (auto const& n : names) {
|
||||||
|
this->SetParam(kEvalMetric, n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (warn_old_model) {
|
||||||
|
LOG(WARNING) << "Loading model from XGBoost < 1.0.0, consider saving it "
|
||||||
|
"again for improved compatibility";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renew the version.
|
||||||
|
mparam_.major_version = std::get<0>(Version::Self());
|
||||||
|
mparam_.minor_version = std::get<1>(Version::Self());
|
||||||
|
|
||||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||||
|
|
||||||
auto n = tparam_.__DICT__();
|
auto n = tparam_.__DICT__();
|
||||||
cfg_.insert(n.cbegin(), n.cend());
|
cfg_.insert(n.cbegin(), n.cend());
|
||||||
|
|
||||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
|
||||||
generic_parameters_.UpdateAllowUnknown(args);
|
|
||||||
gbm_->Configure(args);
|
|
||||||
obj_->Configure({cfg_.begin(), cfg_.end()});
|
|
||||||
|
|
||||||
for (auto& p_metric : metrics_) {
|
|
||||||
p_metric->Configure({cfg_.begin(), cfg_.end()});
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy dsplit from config since it will not run again during restore
|
// copy dsplit from config since it will not run again during restore
|
||||||
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
|
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
|
||||||
tparam_.dsplit = DataSplitMode::kRow;
|
tparam_.dsplit = DataSplitMode::kRow;
|
||||||
@@ -554,15 +583,8 @@ class LearnerImpl : public Learner {
|
|||||||
void SaveModel(dmlc::Stream* fo) const override {
|
void SaveModel(dmlc::Stream* fo) const override {
|
||||||
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
|
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
|
||||||
std::vector<std::pair<std::string, std::string> > extra_attr;
|
std::vector<std::pair<std::string, std::string> > extra_attr;
|
||||||
// extra attributed to be added just before saving
|
|
||||||
if (tparam_.objective == "count:poisson") {
|
|
||||||
auto it = cfg_.find("max_delta_step");
|
|
||||||
if (it != cfg_.end()) {
|
|
||||||
// write `max_delta_step` parameter as extra attribute of booster
|
|
||||||
mparam.contain_extra_attrs = 1;
|
mparam.contain_extra_attrs = 1;
|
||||||
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
{
|
{
|
||||||
std::vector<std::string> saved_params;
|
std::vector<std::string> saved_params;
|
||||||
// check if rabit_bootstrap_cache were set to non zero before adding to checkpoint
|
// check if rabit_bootstrap_cache were set to non zero before adding to checkpoint
|
||||||
@@ -579,6 +601,24 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
// Similar to JSON model IO, we save the objective.
|
||||||
|
Json j_obj { Object() };
|
||||||
|
obj_->SaveConfig(&j_obj);
|
||||||
|
std::string obj_doc;
|
||||||
|
Json::Dump(j_obj, &obj_doc);
|
||||||
|
extra_attr.emplace_back("objective", obj_doc);
|
||||||
|
}
|
||||||
|
// As of 1.0.0, JVM Package and R Package uses Save/Load model for serialization.
|
||||||
|
// Remove this part once they are ported to use actual serialization methods.
|
||||||
|
if (mparam.contain_eval_metrics != 0) {
|
||||||
|
std::stringstream os;
|
||||||
|
for (auto& ev : metrics_) {
|
||||||
|
os << ev->Name() << ";";
|
||||||
|
}
|
||||||
|
extra_attr.emplace_back("metrics", os.str());
|
||||||
|
}
|
||||||
|
|
||||||
fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
|
fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
|
||||||
fo->Write(tparam_.objective);
|
fo->Write(tparam_.objective);
|
||||||
fo->Write(tparam_.booster);
|
fo->Write(tparam_.booster);
|
||||||
@@ -591,25 +631,6 @@ class LearnerImpl : public Learner {
|
|||||||
fo->Write(std::vector<std::pair<std::string, std::string>>(
|
fo->Write(std::vector<std::pair<std::string, std::string>>(
|
||||||
attr.begin(), attr.end()));
|
attr.begin(), attr.end()));
|
||||||
}
|
}
|
||||||
if (tparam_.objective == "count:poisson") {
|
|
||||||
auto it = cfg_.find("max_delta_step");
|
|
||||||
if (it != cfg_.end()) {
|
|
||||||
fo->Write(it->second);
|
|
||||||
} else {
|
|
||||||
// recover value of max_delta_step from extra attributes
|
|
||||||
auto it2 = attributes_.find("count_poisson_max_delta_step");
|
|
||||||
const std::string max_delta_step
|
|
||||||
= (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue;
|
|
||||||
fo->Write(max_delta_step);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (mparam.contain_eval_metrics != 0) {
|
|
||||||
std::vector<std::string> metr;
|
|
||||||
for (auto& ev : metrics_) {
|
|
||||||
metr.emplace_back(ev->Name());
|
|
||||||
}
|
|
||||||
fo->Write(metr);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Save(dmlc::Stream* fo) const override {
|
void Save(dmlc::Stream* fo) const override {
|
||||||
@@ -663,11 +684,13 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
If you are loading a serialized model (like pickle in Python) generated by older
|
If you are loading a serialized model (like pickle in Python) generated by older
|
||||||
XGBoost, please export the model by calling `Booster.save_model` from that version
|
XGBoost, please export the model by calling `Booster.save_model` from that version
|
||||||
first, then load it back in current version. See:
|
first, then load it back in current version. There's a simple script for helping
|
||||||
|
the process. See:
|
||||||
|
|
||||||
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
|
||||||
|
|
||||||
for more details about differences between saving model and serializing.
|
for reference to the script, and more details about differences between saving model and
|
||||||
|
serializing.
|
||||||
|
|
||||||
)doc";
|
)doc";
|
||||||
int64_t sz {-1};
|
int64_t sz {-1};
|
||||||
@@ -856,7 +879,8 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
|
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
|
||||||
// Once binary IO is gone, NONE of these config is useful.
|
// Once binary IO is gone, NONE of these config is useful.
|
||||||
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0") {
|
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" &&
|
||||||
|
tparam_.objective != "multi:softprob") {
|
||||||
cfg_["num_output_group"] = cfg_["num_class"];
|
cfg_["num_output_group"] = cfg_["num_class"];
|
||||||
if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
|
if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
|
||||||
tparam_.objective = "multi:softmax";
|
tparam_.objective = "multi:softmax";
|
||||||
@@ -921,7 +945,6 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
CHECK_NE(mparam_.num_feature, 0)
|
CHECK_NE(mparam_.num_feature, 0)
|
||||||
<< "0 feature is supplied. Are you using raw Booster interface?";
|
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||||
learner_model_param_.num_feature = mparam_.num_feature;
|
|
||||||
// Remove these once binary IO is gone.
|
// Remove these once binary IO is gone.
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ ARG CMAKE_VERSION=3.12
|
|||||||
|
|
||||||
# Environment
|
# Environment
|
||||||
ENV DEBIAN_FRONTEND noninteractive
|
ENV DEBIAN_FRONTEND noninteractive
|
||||||
|
SHELL ["/bin/bash", "-c"] # Use Bash as shell
|
||||||
|
|
||||||
# Install all basic requirements
|
# Install all basic requirements
|
||||||
RUN \
|
RUN \
|
||||||
@@ -19,10 +20,17 @@ ENV PATH=/opt/python/bin:$PATH
|
|||||||
|
|
||||||
ENV GOSU_VERSION 1.10
|
ENV GOSU_VERSION 1.10
|
||||||
|
|
||||||
# Install Python packages
|
# Create new Conda environment with Python 3.5
|
||||||
|
RUN conda create -n py35 python=3.5 && \
|
||||||
|
source activate py35 && \
|
||||||
|
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
|
||||||
|
source deactivate
|
||||||
|
|
||||||
|
# Install Python packages in default env
|
||||||
RUN \
|
RUN \
|
||||||
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \
|
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh \
|
||||||
breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 jsonschema && \
|
recommonmark guzzle_sphinx_theme mock breathe graphviz \
|
||||||
|
pytest scikit-learn wheel kubernetes urllib3 jsonschema boto3 && \
|
||||||
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
|
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
|
||||||
pip install "dask[complete]"
|
pip install "dask[complete]"
|
||||||
|
|
||||||
|
|||||||
@@ -5,31 +5,35 @@ set -x
|
|||||||
suite=$1
|
suite=$1
|
||||||
|
|
||||||
# Install XGBoost Python package
|
# Install XGBoost Python package
|
||||||
wheel_found=0
|
function install_xgboost {
|
||||||
for file in python-package/dist/*.whl
|
wheel_found=0
|
||||||
do
|
for file in python-package/dist/*.whl
|
||||||
|
do
|
||||||
if [ -e "${file}" ]
|
if [ -e "${file}" ]
|
||||||
then
|
then
|
||||||
pip install --user "${file}"
|
pip install --user "${file}"
|
||||||
wheel_found=1
|
wheel_found=1
|
||||||
break # need just one
|
break # need just one
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
if [ "$wheel_found" -eq 0 ]
|
if [ "$wheel_found" -eq 0 ]
|
||||||
then
|
then
|
||||||
pushd .
|
pushd .
|
||||||
cd python-package
|
cd python-package
|
||||||
python setup.py install --user
|
python setup.py install --user
|
||||||
popd
|
popd
|
||||||
fi
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
# Run specified test suite
|
# Run specified test suite
|
||||||
case "$suite" in
|
case "$suite" in
|
||||||
gpu)
|
gpu)
|
||||||
|
install_xgboost
|
||||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu
|
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu
|
||||||
;;
|
;;
|
||||||
|
|
||||||
mgpu)
|
mgpu)
|
||||||
|
install_xgboost
|
||||||
pytest -v -s --fulltrace -m "mgpu" tests/python-gpu
|
pytest -v -s --fulltrace -m "mgpu" tests/python-gpu
|
||||||
cd tests/distributed
|
cd tests/distributed
|
||||||
./runtests-gpu.sh
|
./runtests-gpu.sh
|
||||||
@@ -39,17 +43,25 @@ case "$suite" in
|
|||||||
|
|
||||||
cudf)
|
cudf)
|
||||||
source activate cudf_test
|
source activate cudf_test
|
||||||
|
install_xgboost
|
||||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
|
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
|
||||||
;;
|
;;
|
||||||
|
|
||||||
cpu)
|
cpu)
|
||||||
|
install_xgboost
|
||||||
pytest -v -s --fulltrace tests/python
|
pytest -v -s --fulltrace tests/python
|
||||||
cd tests/distributed
|
cd tests/distributed
|
||||||
./runtests.sh
|
./runtests.sh
|
||||||
;;
|
;;
|
||||||
|
|
||||||
|
cpu-py35)
|
||||||
|
source activate py35
|
||||||
|
install_xgboost
|
||||||
|
pytest -v -s --fulltrace tests/python
|
||||||
|
;;
|
||||||
|
|
||||||
*)
|
*)
|
||||||
echo "Usage: $0 {gpu|mgpu|cudf|cpu}"
|
echo "Usage: $0 {gpu|mgpu|cudf|cpu|cpu-py35}"
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ TEST(Version, Basic) {
|
|||||||
|
|
||||||
ptr = 0;
|
ptr = 0;
|
||||||
v = std::stoi(str, &ptr);
|
v = std::stoi(str, &ptr);
|
||||||
ASSERT_EQ(v, XGBOOST_VER_MINOR) << "patch: " << v;;
|
ASSERT_EQ(v, XGBOOST_VER_PATCH) << "patch: " << v;;
|
||||||
|
|
||||||
str = str.substr(ptr);
|
str = str.substr(ptr);
|
||||||
ASSERT_EQ(str.size(), 0);
|
ASSERT_EQ(str.size(), 0);
|
||||||
|
|||||||
@@ -180,6 +180,41 @@ TEST(Learner, JsonModelIO) {
|
|||||||
delete pp_dmat;
|
delete pp_dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Round-trips a trained learner through the binary model format and checks
// that the evaluation metric set before training is restored on load.
TEST(Learner, BinaryModelIO) {
  size_t constexpr kRows = 8;
  int32_t constexpr kIters = 4;
  auto pp_dmat = CreateDMatrix(kRows, 10, 0);
  std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
  p_dmat->Info().labels_.Resize(kRows);

  std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
  learner->SetParam("eval_metric", "rmsle");
  learner->Configure();
  for (int32_t iter = 0; iter < kIters; ++iter) {
    learner->UpdateOneIter(iter, p_dmat.get());
  }
  dmlc::TemporaryDirectory tempdir;
  // Join with an explicit separator so the file lives inside the temporary
  // directory.  Without it the name is glued onto the directory name and the
  // file is created next to the directory instead.
  // NOTE(review): assumes tempdir.path has no trailing separator and that the
  // directory contents are removed when tempdir goes out of scope -- confirm
  // against dmlc-core.
  std::string const fname = tempdir.path + "/binary_model_io.bin";
  {
    // Make sure the write is complete before loading.
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
    learner->SaveModel(fo.get());
  }

  // Start from a fresh learner so everything checked below comes from the file.
  learner.reset(Learner::Create({p_dmat}));
  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
  learner->LoadModel(fi.get());
  learner->Configure();
  Json config { Object() };
  learner->SaveConfig(&config);
  std::string config_str;
  Json::Dump(config, &config_str);
  // The metric must survive the round trip, and the dumped configuration
  // must not contain the text "WARNING".
  ASSERT_NE(config_str.find("rmsle"), std::string::npos);
  ASSERT_EQ(config_str.find("WARNING"), std::string::npos);

  delete pp_dmat;
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
// Tests for automatic GPU configuration.
|
// Tests for automatic GPU configuration.
|
||||||
TEST(Learner, GPUConfiguration) {
|
TEST(Learner, GPUConfiguration) {
|
||||||
|
|||||||
148
tests/python/generate_models.py
Normal file
148
tests/python/generate_models.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import xgboost
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Shape of the generated models and of the synthetic training data.
kRounds = 2        # boosting rounds per model
kRows = 1000       # rows in the synthetic data set
kCols = 4          # feature columns
kForests = 2       # value passed as num_parallel_tree
kMaxDepth = 2      # value passed as max_depth
kClasses = 3       # classes for the multi-class model

# Seed NumPy's global generator BEFORE the first draw.  Seeding after X and w
# were created left the features and weights different on every run, so only
# the labels were reproducible.
np.random.seed(1994)

# Feature matrix and per-row weights shared by the model generators below.
X = np.random.randn(kRows, kCols)
w = np.random.uniform(size=kRows)

# Version of the installed XGBoost; embedded in every output file name.
version = xgboost.__version__

# Output directory, relative to the current working directory.
target_dir = 'models'
|
||||||
|
|
||||||
|
|
||||||
|
def booster_bin(model):
    """Return the output path of the binary model saved from a raw Booster.

    ``model`` is the short model tag (for example ``'reg'``); the file is
    named ``xgboost-<version>.<model>.bin`` and placed in ``target_dir``.
    """
    file_name = 'xgboost-' + version + '.' + model + '.bin'
    return os.path.join(target_dir, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def booster_json(model):
    """Return the output path of the JSON model saved from a raw Booster.

    ``model`` is the short model tag (for example ``'reg'``); the file is
    named ``xgboost-<version>.<model>.json`` and placed in ``target_dir``.
    """
    file_name = 'xgboost-' + version + '.' + model + '.json'
    return os.path.join(target_dir, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def skl_bin(model):
    """Return the output path of the binary model saved from a scikit-learn wrapper.

    ``model`` is the short model tag (for example ``'reg'``); the file is
    named ``xgboost_scikit-<version>.<model>.bin`` and placed in ``target_dir``.
    """
    file_name = 'xgboost_scikit-' + version + '.' + model + '.bin'
    return os.path.join(target_dir, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def skl_json(model):
    """Return the output path of the JSON model saved from a scikit-learn wrapper.

    ``model`` is the short model tag (for example ``'reg'``); the file is
    named ``xgboost_scikit-<version>.<model>.json`` and placed in ``target_dir``.
    """
    file_name = 'xgboost_scikit-' + version + '.' + model + '.json'
    return os.path.join(target_dir, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_regression_model():
    """Train and save the regression models (tag ``reg``).

    Labels are drawn from a standard normal distribution.  One model is
    trained through ``xgboost.train`` and one through ``XGBRegressor``; each
    is written in both binary and JSON form.
    """
    print('Regression')
    y = np.random.randn(kRows)

    # Native interface.
    dtrain = xgboost.DMatrix(X, label=y, weight=w)
    params = {
        'tree_method': 'hist',
        'num_parallel_tree': kForests,
        'max_depth': kMaxDepth,
    }
    booster = xgboost.train(params, num_boost_round=kRounds, dtrain=dtrain)
    for path in (booster_bin('reg'), booster_json('reg')):
        booster.save_model(path)

    # Scikit-learn interface.
    regressor = xgboost.XGBRegressor(
        tree_method='hist',
        num_parallel_tree=kForests,
        max_depth=kMaxDepth,
        n_estimators=kRounds,
    )
    regressor.fit(X, y, w)
    for path in (skl_bin('reg'), skl_json('reg')):
        regressor.save_model(path)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_logistic_model():
    """Train and save the binary classification models (tag ``logit``).

    Labels are random integers in {0, 1}.  One model is trained through
    ``xgboost.train`` with ``binary:logistic`` and one through
    ``XGBClassifier``; each is written in both binary and JSON form.
    """
    print('Logistic')
    y = np.random.randint(0, 2, size=kRows)
    # Both classes must be present in the sample.
    assert y.max() == 1 and y.min() == 0

    # Native interface.
    dtrain = xgboost.DMatrix(X, label=y, weight=w)
    params = {
        'tree_method': 'hist',
        'num_parallel_tree': kForests,
        'max_depth': kMaxDepth,
        'objective': 'binary:logistic',
    }
    booster = xgboost.train(params, num_boost_round=kRounds, dtrain=dtrain)
    for path in (booster_bin('logit'), booster_json('logit')):
        booster.save_model(path)

    # Scikit-learn interface.
    classifier = xgboost.XGBClassifier(
        tree_method='hist',
        num_parallel_tree=kForests,
        max_depth=kMaxDepth,
        n_estimators=kRounds,
    )
    classifier.fit(X, y, w)
    for path in (skl_bin('logit'), skl_json('logit')):
        classifier.save_model(path)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_classification_model():
    """Train and save the multi-class models (tag ``cls``).

    Labels are random integers in ``[0, kClasses)``.  One model is trained
    through ``xgboost.train`` and one through ``XGBClassifier``; each is
    written in both binary and JSON form.
    """
    print('Classification')
    y = np.random.randint(0, kClasses, size=kRows)

    # Native interface.  Only num_class is given; no objective is set here.
    dtrain = xgboost.DMatrix(X, label=y, weight=w)
    params = {
        'num_class': kClasses,
        'tree_method': 'hist',
        'num_parallel_tree': kForests,
        'max_depth': kMaxDepth,
    }
    booster = xgboost.train(params, num_boost_round=kRounds, dtrain=dtrain)
    for path in (booster_bin('cls'), booster_json('cls')):
        booster.save_model(path)

    # Scikit-learn interface.
    classifier = xgboost.XGBClassifier(
        tree_method='hist',
        num_parallel_tree=kForests,
        max_depth=kMaxDepth,
        n_estimators=kRounds,
    )
    classifier.fit(X, y, w)
    for path in (skl_bin('cls'), skl_json('cls')):
        classifier.save_model(path)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ranking_model():
    """Train and save the learning-to-rank models (tag ``ltr``).

    Labels are random integers in ``[0, 5)``.  The rows are split into 20
    groups of 50, and a weight vector of length 20 (local to this function,
    not the module-level ``w``) is passed as the weights.  One model is
    trained through ``xgboost.train`` and one through ``XGBRanker``, both
    with ``rank:ndcg``; each is written in both binary and JSON form.
    """
    print('Learning to Rank')
    y = np.random.randint(5, size=kRows)
    group_weights = np.random.uniform(size=20)
    group_sizes = np.repeat(50, 20)

    # Native interface.
    dtrain = xgboost.DMatrix(X, y, weight=group_weights)
    dtrain.set_group(group_sizes)
    params = {
        'objective': 'rank:ndcg',
        'num_parallel_tree': kForests,
        'tree_method': 'hist',
        'max_depth': kMaxDepth,
    }
    booster = xgboost.train(params, num_boost_round=kRounds, dtrain=dtrain)
    for path in (booster_bin('ltr'), booster_json('ltr')):
        booster.save_model(path)

    # Scikit-learn interface.
    ranker = xgboost.sklearn.XGBRanker(
        n_estimators=kRounds,
        tree_method='hist',
        objective='rank:ndcg',
        max_depth=kMaxDepth,
        num_parallel_tree=kForests,
    )
    ranker.fit(X, y, group_sizes, sample_weight=group_weights)
    for path in (skl_bin('ltr'), skl_json('ltr')):
        ranker.save_model(path)
|
||||||
|
|
||||||
|
|
||||||
|
def write_versions():
    """Record which NumPy and XGBoost versions produced the models.

    The information is written to the file ``version`` inside ``target_dir``
    as the ``str()`` representation of a dict.
    """
    versions = {'numpy': np.__version__,
                'xgboost': version}
    version_file = os.path.join(target_dir, 'version')
    with open(version_file, 'w') as fd:
        fd.write(str(versions))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Create the output directory on first use.
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)

    # Every generator draws its labels from NumPy's shared global generator,
    # so the call order below determines which values each model sees.
    generators = (
        generate_regression_model,
        generate_logistic_model,
        generate_classification_model,
        generate_ranking_model,
    )
    for generate in generators:
        generate()

    write_versions()
|
||||||
@@ -39,7 +39,7 @@ class TestBasic(unittest.TestCase):
|
|||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
param = {'max_depth': 2, 'eta': 1,
|
||||||
'objective': 'binary:logistic'}
|
'objective': 'binary:logistic'}
|
||||||
# specify validations set to watch performance
|
# specify validations set to watch performance
|
||||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import testing as tm
|
import testing as tm
|
||||||
import pytest
|
import pytest
|
||||||
|
import locale
|
||||||
|
|
||||||
dpath = 'demo/data/'
|
dpath = 'demo/data/'
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
@@ -284,25 +285,42 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertRaises(ValueError, bst.predict, dm1)
|
self.assertRaises(ValueError, bst.predict, dm1)
|
||||||
bst.predict(dm2) # success
|
bst.predict(dm2) # success
|
||||||
|
|
||||||
|
def test_model_binary_io(self):
    """Check that an objective parameter survives a binary save/load round trip.

    A small model is trained with ``scale_pos_weight`` set to 0.5, saved in
    the binary format and loaded into a fresh Booster.  The configuration of
    the loaded Booster must still report that value.
    """
    model_path = 'test_model_binary_io.bin'
    parameters = {'tree_method': 'hist', 'booster': 'gbtree',
                  'scale_pos_weight': '0.5'}
    X = np.random.random((10, 3))
    y = np.random.random((10,))
    dtrain = xgb.DMatrix(X, y)
    bst = xgb.train(parameters, dtrain, num_boost_round=2)
    try:
        bst.save_model(model_path)
        bst = xgb.Booster(model_file=model_path)
    finally:
        # Clean up even when saving or loading raises, so a failing run does
        # not leave the model file behind in the working directory.
        if os.path.exists(model_path):
            os.remove(model_path)
    config = json.loads(bst.save_config())
    assert float(config['learner']['objective'][
        'reg_loss_param']['scale_pos_weight']) == 0.5
|
||||||
|
|
||||||
def test_model_json_io(self):
|
def test_model_json_io(self):
|
||||||
model_path = './model.json'
|
loc = locale.getpreferredencoding(False)
|
||||||
|
model_path = 'test_model_json_io.json'
|
||||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
||||||
j_model = json_model(model_path, parameters)
|
j_model = json_model(model_path, parameters)
|
||||||
assert isinstance(j_model['learner'], dict)
|
assert isinstance(j_model['learner'], dict)
|
||||||
|
|
||||||
bst = xgb.Booster(model_file='./model.json')
|
bst = xgb.Booster(model_file=model_path)
|
||||||
|
|
||||||
bst.save_model(fname=model_path)
|
bst.save_model(fname=model_path)
|
||||||
with open('./model.json', 'r') as fd:
|
with open(model_path, 'r') as fd:
|
||||||
j_model = json.load(fd)
|
j_model = json.load(fd)
|
||||||
assert isinstance(j_model['learner'], dict)
|
assert isinstance(j_model['learner'], dict)
|
||||||
|
|
||||||
os.remove(model_path)
|
os.remove(model_path)
|
||||||
|
assert locale.getpreferredencoding(False) == loc
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_json_schema())
|
@pytest.mark.skipif(**tm.no_json_schema())
|
||||||
def test_json_schema(self):
|
def test_json_schema(self):
|
||||||
import jsonschema
|
import jsonschema
|
||||||
model_path = './model.json'
|
model_path = 'test_json_schema.json'
|
||||||
path = os.path.dirname(
|
path = os.path.dirname(
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
doc = os.path.join(path, 'doc', 'model.schema')
|
doc = os.path.join(path, 'doc', 'model.schema')
|
||||||
|
|||||||
130
tests/python/test_model_compatibility.py
Normal file
130
tests/python/test_model_compatibility.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import xgboost
|
||||||
|
import os
|
||||||
|
import generate_models as gm
|
||||||
|
import json
|
||||||
|
import zipfile
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_param_check(config):
    '''Assert that a parsed model configuration has the expected basics.

    Parameters
    ----------
    config : dict
        Parsed JSON configuration.  ``config['learner']`` must provide the
        ``learner_model_param`` and ``learner_train_param`` sections.

    Raises
    ------
    AssertionError
        If ``num_feature`` is not the string ``'4'`` or the booster is not
        ``'gbtree'``.
    '''
    learner = config['learner']
    # The feature count is compared against its string form.
    assert learner['learner_model_param']['num_feature'] == str(4)
    assert learner['learner_train_param']['booster'] == 'gbtree'
|
||||||
|
|
||||||
|
|
||||||
|
def run_booster_check(booster, name):
    '''Check a loaded booster against the model kind encoded in its name.

    Parameters
    ----------
    booster :
        Object exposing ``save_config()`` (returning a JSON string) and
        ``get_dump()``.
    name : str
        File name of the model.  The first substring found among ``cls``,
        ``logit`` and ``ltr`` selects the checks; any other name must
        contain ``reg``.

    Raises
    ------
    AssertionError
        If the configuration or the number of dumped trees does not match
        what is expected for the detected model kind.
    '''
    config = json.loads(booster.save_config())
    run_model_param_check(config)

    learner = config['learner']
    if 'cls' in name:
        # Multi-class model: the expected dump size includes a factor of
        # gm.kClasses.
        assert (len(booster.get_dump()) == gm.kForests * gm.kRounds *
                gm.kClasses)
        assert float(learner['learner_model_param']['base_score']) == 0.5
        assert learner['learner_train_param']['objective'] == 'multi:softmax'
    elif 'logit' in name:
        assert len(booster.get_dump()) == gm.kForests * gm.kRounds
        assert learner['learner_model_param']['num_class'] == str(0)
        assert (learner['learner_train_param']['objective'] ==
                'binary:logistic')
    elif 'ltr' in name:
        # Only the objective is verified for ranking models.
        assert learner['learner_train_param']['objective'] == 'rank:ndcg'
    else:
        # Anything left over has to be a regression model.
        assert 'reg' in name
        assert len(booster.get_dump()) == gm.kForests * gm.kRounds
        assert float(learner['learner_model_param']['base_score']) == 0.5
        assert (learner['learner_train_param']['objective'] ==
                'reg:squarederror')
|
||||||
|
|
||||||
|
|
||||||
|
def run_scikit_model_check(name, path):
    '''Load a saved model through the scikit-learn wrappers and verify it.

    Parameters
    ----------
    name : str
        File name of the model.  The first substring found among ``reg``,
        ``cls``, ``ltr`` and ``logit`` selects the wrapper class and the
        checks; a name matching none of them fails the check.
    path : str
        Location of the model file handed to ``load_model``.

    Raises
    ------
    AssertionError
        If the loaded model does not match the expectations for its kind,
        or if `name` matches no known kind.
    '''
    is_090 = '0.90' in name

    if 'reg' in name:
        regressor = xgboost.XGBRegressor()
        regressor.load_model(path)
        config = json.loads(regressor.get_booster().save_config())
        # Files whose name contains '0.90' are expected to carry the
        # 'reg:linear' objective instead of 'reg:squarederror'.
        objective = config['learner']['learner_train_param']['objective']
        if is_090:
            assert objective == 'reg:linear'
        else:
            assert objective == 'reg:squarederror'
        assert (len(regressor.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        run_model_param_check(config)
    elif 'cls' in name:
        classifier = xgboost.XGBClassifier()
        classifier.load_model(path)
        if not is_090:
            # Class metadata is only checked for files without '0.90' in
            # their name.
            assert len(classifier.classes_) == gm.kClasses
            assert len(classifier._le.classes_) == gm.kClasses
            assert classifier.n_classes_ == gm.kClasses
        assert (len(classifier.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests * gm.kClasses), path
        config = json.loads(classifier.get_booster().save_config())
        assert (config['learner']['learner_train_param']['objective'] ==
                'multi:softprob'), path
        run_model_param_check(config)
    elif 'ltr' in name:
        ranker = xgboost.XGBRanker()
        ranker.load_model(path)
        assert (len(ranker.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(ranker.get_booster().save_config())
        assert (config['learner']['learner_train_param']['objective'] ==
                'rank:ndcg')
        run_model_param_check(config)
    elif 'logit' in name:
        binary_classifier = xgboost.XGBClassifier()
        binary_classifier.load_model(path)
        assert (len(binary_classifier.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(binary_classifier.get_booster().save_config())
        assert (config['learner']['learner_train_param']['objective'] ==
                'binary:logistic')
    else:
        assert False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.ci
def test_model_compatibility():
    '''Test model compatibility, can only be run on CI as others don't
    have the credentials.

    Downloads an archive of previously saved models from S3, extracts it
    into the ``models`` directory next to this file and runs the matching
    check on every extracted file, dispatching on the file name prefix.
    The test is skipped when ``boto3`` is not installed or when no AWS
    credentials are available.
    '''
    path = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(path, 'models')
    try:
        import boto3
        import botocore
    except ImportError:
        pytest.skip(
            'Skiping compatibility tests as boto3 is not installed.')

    try:
        s3_bucket = boto3.resource('s3').Bucket('xgboost-ci-jenkins-artifacts')
        # The same name is used as the S3 key and as the local file name.
        zip_path = 'xgboost_model_compatibility_test.zip'
        s3_bucket.download_file(zip_path, zip_path)
    except botocore.exceptions.NoCredentialsError:
        pytest.skip(
            'Skiping compatibility tests as running on non-CI environment.')

    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(path)

    # Every extracted file except those named 'version' is treated as a
    # model.
    models = [
        os.path.join(root, f) for root, _, files in os.walk(path)
        for f in files
        if f != 'version'
    ]
    assert models, 'No model files found under: ' + path

    # A dedicated loop variable keeps the extraction directory in `path`
    # intact instead of shadowing it.
    for model_path in models:
        name = os.path.basename(model_path)
        if name.startswith('xgboost-'):
            booster = xgboost.Booster(model_file=model_path)
            run_booster_check(booster, name)
        elif name.startswith('xgboost_scikit'):
            run_scikit_model_check(name, model_path)
        else:
            # Name the offending file so a failure is diagnosable.
            raise AssertionError(
                'Unexpected file in model archive: ' + model_path)
|
||||||
@@ -115,7 +115,6 @@ class TestRanking(unittest.TestCase):
|
|||||||
# model training parameters
|
# model training parameters
|
||||||
cls.params = {'objective': 'rank:pairwise',
|
cls.params = {'objective': 'rank:pairwise',
|
||||||
'booster': 'gbtree',
|
'booster': 'gbtree',
|
||||||
'silent': 0,
|
|
||||||
'eval_metric': ['ndcg']
|
'eval_metric': ['ndcg']
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +152,8 @@ class TestRanking(unittest.TestCase):
|
|||||||
Test cross-validation with a group specified
|
Test cross-validation with a group specified
|
||||||
"""
|
"""
|
||||||
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
|
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
|
||||||
early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False)
|
early_stopping_rounds=10, shuffle=False, nfold=10,
|
||||||
|
as_pandas=False)
|
||||||
assert isinstance(cv, dict)
|
assert isinstance(cv, dict)
|
||||||
assert len(cv) == 4
|
assert len(cv) == 4
|
||||||
|
|
||||||
|
|||||||
@@ -10,18 +10,20 @@ if sys.platform.startswith("win"):
|
|||||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from distributed.utils_test import client, loop, cluster_fixture
|
from distributed import LocalCluster, Client
|
||||||
import dask.dataframe as dd
|
import dask.dataframe as dd
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
from xgboost.dask import DaskDMatrix
|
from xgboost.dask import DaskDMatrix
|
||||||
except ImportError:
|
except ImportError:
|
||||||
client = None
|
LocalCluster = None
|
||||||
loop = None
|
Client = None
|
||||||
cluster_fixture = None
|
dd = None
|
||||||
pass
|
da = None
|
||||||
|
DaskDMatrix = None
|
||||||
|
|
||||||
kRows = 1000
|
kRows = 1000
|
||||||
kCols = 10
|
kCols = 10
|
||||||
|
kWorkers = 5
|
||||||
|
|
||||||
|
|
||||||
def generate_array():
|
def generate_array():
|
||||||
@@ -31,7 +33,9 @@ def generate_array():
|
|||||||
return X, y
|
return X, y
|
||||||
|
|
||||||
|
|
||||||
def test_from_dask_dataframe(client):
|
def test_from_dask_dataframe():
|
||||||
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
|
|
||||||
X = dd.from_dask_array(X)
|
X = dd.from_dask_array(X)
|
||||||
@@ -51,11 +55,13 @@ def test_from_dask_dataframe(client):
|
|||||||
# evals_result is not supported in dask interface.
|
# evals_result is not supported in dask interface.
|
||||||
xgb.dask.train(
|
xgb.dask.train(
|
||||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||||
|
# force prediction to be computed
|
||||||
prediction = prediction.compute() # force prediction to be computed
|
prediction = prediction.compute()
|
||||||
|
|
||||||
|
|
||||||
def test_from_dask_array(client):
|
def test_from_dask_array():
|
||||||
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
dtrain = DaskDMatrix(client, X, y)
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
# results is {'booster': Booster, 'history': {...}}
|
# results is {'booster': Booster, 'history': {...}}
|
||||||
@@ -65,11 +71,13 @@ def test_from_dask_array(client):
|
|||||||
assert prediction.shape[0] == kRows
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
assert isinstance(prediction, da.Array)
|
assert isinstance(prediction, da.Array)
|
||||||
|
# force prediction to be computed
|
||||||
prediction = prediction.compute() # force prediction to be computed
|
prediction = prediction.compute()
|
||||||
|
|
||||||
|
|
||||||
def test_regressor(client):
|
def test_dask_regressor():
|
||||||
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||||
regressor.set_params(tree_method='hist')
|
regressor.set_params(tree_method='hist')
|
||||||
@@ -89,10 +97,13 @@ def test_regressor(client):
|
|||||||
assert len(history['validation_0']['rmse']) == 2
|
assert len(history['validation_0']['rmse']) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_classifier(client):
|
def test_dask_classifier():
|
||||||
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
y = (y * 10).astype(np.int32)
|
y = (y * 10).astype(np.int32)
|
||||||
classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2)
|
classifier = xgb.dask.DaskXGBClassifier(
|
||||||
|
verbosity=1, n_estimators=2)
|
||||||
classifier.client = client
|
classifier.client = client
|
||||||
classifier.fit(X, y, eval_set=[(X, y)])
|
classifier.fit(X, y, eval_set=[(X, y)])
|
||||||
prediction = classifier.predict(X)
|
prediction = classifier.predict(X)
|
||||||
@@ -164,11 +175,15 @@ def run_empty_dmatrix(client, parameters):
|
|||||||
# No test for Exact, as empty DMatrix handling are mostly for distributed
|
# No test for Exact, as empty DMatrix handling are mostly for distributed
|
||||||
# environment and Exact doesn't support it.
|
# environment and Exact doesn't support it.
|
||||||
|
|
||||||
def test_empty_dmatrix_hist(client):
|
def test_empty_dmatrix_hist():
|
||||||
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
parameters = {'tree_method': 'hist'}
|
parameters = {'tree_method': 'hist'}
|
||||||
run_empty_dmatrix(client, parameters)
|
run_empty_dmatrix(client, parameters)
|
||||||
|
|
||||||
|
|
||||||
def test_empty_dmatrix_approx(client):
|
def test_empty_dmatrix_approx():
|
||||||
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
parameters = {'tree_method': 'approx'}
|
parameters = {'tree_method': 'approx'}
|
||||||
run_empty_dmatrix(client, parameters)
|
run_empty_dmatrix(client, parameters)
|
||||||
|
|||||||
@@ -490,6 +490,13 @@ def test_kwargs():
|
|||||||
assert clf.get_params()['n_estimators'] == 1000
|
assert clf.get_params()['n_estimators'] == 1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_kwargs_error():
    '''Supplying ``n_jobs`` twice must be rejected with ``TypeError``.

    ``n_jobs`` is passed both as an explicit keyword and through
    ``**params``, so the constructor call itself raises.
    '''
    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
    with pytest.raises(TypeError):
        # Nothing after this call inside the ``with`` block would ever
        # run, so the result is neither bound nor inspected.
        xgb.XGBClassifier(n_jobs=1000, **params)
|
||||||
|
|
||||||
|
|
||||||
def test_kwargs_grid_search():
|
def test_kwargs_grid_search():
|
||||||
from sklearn.model_selection import GridSearchCV
|
from sklearn.model_selection import GridSearchCV
|
||||||
from sklearn import datasets
|
from sklearn import datasets
|
||||||
@@ -510,13 +517,6 @@ def test_kwargs_grid_search():
|
|||||||
assert len(means) == len(set(means))
|
assert len(means) == len(set(means))
|
||||||
|
|
||||||
|
|
||||||
def test_kwargs_error():
|
|
||||||
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
clf = xgb.XGBClassifier(n_jobs=1000, **params)
|
|
||||||
assert isinstance(clf, xgb.XGBClassifier)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sklearn_clone():
|
def test_sklearn_clone():
|
||||||
from sklearn.base import clone
|
from sklearn.base import clone
|
||||||
|
|
||||||
@@ -525,6 +525,17 @@ def test_sklearn_clone():
|
|||||||
clone(clf)
|
clone(clf)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sklearn_get_default_params():
    '''``base_score`` is reported as ``None`` before fitting and as a
    concrete value afterwards.
    '''
    from sklearn.datasets import load_digits
    # Pass ``n_class`` by keyword: newer scikit-learn releases no longer
    # accept it positionally.
    digits_2class = load_digits(n_class=2)
    X = digits_2class['data']
    y = digits_2class['target']
    cls = xgb.XGBClassifier()
    assert cls.get_params()['base_score'] is None
    # Fitting on the first four samples is enough to populate the value.
    cls.fit(X[:4, ...], y[:4, ...])
    assert cls.get_params()['base_score'] is not None
|
||||||
|
|
||||||
|
|
||||||
def test_validation_weights_xgbmodel():
|
def test_validation_weights_xgbmodel():
|
||||||
from sklearn.datasets import make_hastie_10_2
|
from sklearn.datasets import make_hastie_10_2
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user