Compare commits

...

16 Commits

Author SHA1 Message Date
Hyunsu Cho
74e2f652de Enforce only major version in JSON model schema 2020-02-21 07:57:45 +00:00
Hyunsu Cho
e02fff53f2 Change version_config.h too 2020-02-21 07:50:41 +00:00
Hyunsu Cho
fcb2efbadd Fix a unit test that mistook MINOR ver for PATCH ver 2020-02-21 07:11:59 +00:00
Hyunsu Cho
f4621f09c7 Release 1.0.1 to add #5330 2020-02-20 22:56:32 -08:00
Philip Hyunsu Cho
bf1b2cbfa2 Remove f-string, since it's not supported by Python 3.5 (#5330)
* Remove f-string, since it's not supported by Python 3.5

* Add Python 3.5 to CI, to ensure compatibility

* Remove duplicated matplotlib

* Show deprecation notice for Python 3.5

* Fix lint

* Fix lint
2020-02-20 22:47:05 -08:00
Hyunsu Cho
d90e7b3117 Change version to 1.0.0 2020-02-20 05:02:47 +00:00
Jiaming Yuan
088c43d666 Fix changing locale. (#5314)
* Fix changing locale.

* Don't use locale guard.

As number parsing is implemented in house, we don't need locale.

* Update doc.
2020-02-17 13:01:48 +08:00
Hyunsu Cho
69fc8a632f Change version to 1.0.0rc2 2020-02-14 09:56:47 +00:00
Jiaming Yuan
213f4fa45a Fix loading old logit model, helper for converting old pickle. (#5281)
* Fix loading old logit model.
* Add a helper script for converting old pickle file.
* Add version as a model parameter.
* Remove the size check in R test to relax the size constraint.
* Add missing R doc for passing linting. Run devtools.
* Cleanup old model IO logic.
* Test compatibility on CI.
* Make the argument as required.
2020-02-13 15:28:13 +08:00
Philip Hyunsu Cho
5ca21f252a Declare Python 3.8 support in setup.py (#5274) 2020-02-03 22:59:25 -08:00
Jiaming Yuan
eeb67c3d52 Avoid dask test fixtures. (#5270)
* Fix Travis OSX timeout.

* Fix classifier.
2020-02-03 22:57:36 -08:00
Jiaming Yuan
ed37fdb9c9 Export JSON config in get_params. (#5256) 2020-02-03 06:16:03 +00:00
Philip Hyunsu Cho
e7e522fb06 Remove use of std::cout from R package (#5261) 2020-02-03 06:15:39 +00:00
Nan Zhu
8e39a675be [jvm-packages] do not use multiple jobs to make checkpoints (#5082)
* temp

* temp

* tep

* address the comments

* fix stylistic issues

* fix

* external checkpoint
2020-02-02 21:48:42 -08:00
Jiaming Yuan
7f542d2198 Test model compatibility. (#5248)
* Add model compability tests.

* Typo.
2020-01-31 02:46:13 -08:00
Philip Cho
c8d32102fb Change version to 1.0.0rc1 2020-01-31 01:37:54 -08:00
51 changed files with 1280 additions and 608 deletions

View File

@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.12) cmake_minimum_required(VERSION 3.12)
project(xgboost LANGUAGES CXX C VERSION 1.0.0) project(xgboost LANGUAGES CXX C VERSION 1.0.1)
include(cmake/Utils.cmake) include(cmake/Utils.cmake)
list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules") list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules")
cmake_policy(SET CMP0022 NEW) cmake_policy(SET CMP0022 NEW)
@@ -49,7 +49,7 @@ option(USE_SANITIZER "Use santizer flags" OFF)
option(SANITIZER_PATH "Path to sanitizes.") option(SANITIZER_PATH "Path to sanitizes.")
set(ENABLED_SANITIZERS "address" "leak" CACHE STRING set(ENABLED_SANITIZERS "address" "leak" CACHE STRING
"Semicolon separated list of sanitizer names. E.g 'address;leak'. Supported sanitizers are "Semicolon separated list of sanitizer names. E.g 'address;leak'. Supported sanitizers are
address, leak and thread.") address, leak, undefined and thread.")
## Plugins ## Plugins
option(PLUGIN_LZ4 "Build lz4 plugin" OFF) option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF) option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)

1
Jenkinsfile vendored
View File

@@ -273,6 +273,7 @@ def TestPythonCPU() {
def docker_binary = "docker" def docker_binary = "docker"
sh """ sh """
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu ${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/test_python.sh cpu-py35
""" """
deleteDir() deleteDir()
} }

View File

@@ -139,6 +139,8 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several #' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case. This option has no effect when either of predleaf, predcontrib, #' prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
#' or predinteraction flags is TRUE. #' or predinteraction flags is TRUE.
#' @param training whether is the prediction result used for training. For dart booster,
#' training predicting will perform dropout.
#' @param ... Parameters passed to \code{predict.xgb.Booster} #' @param ... Parameters passed to \code{predict.xgb.Booster}
#' #'
#' @details #' @details

View File

@@ -49,6 +49,9 @@ It will use all the trees by default (\code{NULL} value).}
prediction outputs per case. This option has no effect when either of predleaf, predcontrib, prediction outputs per case. This option has no effect when either of predleaf, predcontrib,
or predinteraction flags is TRUE.} or predinteraction flags is TRUE.}
\item{training}{whether is the prediction result used for training. For dart booster,
training predicting will perform dropout.}
\item{...}{Parameters passed to \code{predict.xgb.Booster}} \item{...}{Parameters passed to \code{predict.xgb.Booster}}
} }
\value{ \value{

View File

@@ -31,7 +31,6 @@ num_round <- 2
test_that("custom objective works", { test_that("custom objective works", {
bst <- xgb.train(param, dtrain, num_round, watchlist) bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster") expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1100)
expect_false(is.null(bst$evaluation_log)) expect_false(is.null(bst$evaluation_log))
expect_false(is.null(bst$evaluation_log$eval_error)) expect_false(is.null(bst$evaluation_log$eval_error))
expect_lt(bst$evaluation_log[num_round, eval_error], 0.03) expect_lt(bst$evaluation_log[num_round, eval_error], 0.03)
@@ -58,5 +57,4 @@ test_that("custom objective using DMatrix attr works", {
param$objective = logregobjattr param$objective = logregobjattr
bst <- xgb.train(param, dtrain, num_round, watchlist) bst <- xgb.train(param, dtrain, num_round, watchlist)
expect_equal(class(bst), "xgb.Booster") expect_equal(class(bst), "xgb.Booster")
expect_equal(length(bst$raw), 1100)
}) })

View File

@@ -1 +1 @@
@xgboost_VERSION_MAJOR@.@xgboost_VERSION_MINOR@.@xgboost_VERSION_PATCH@-SNAPSHOT @xgboost_VERSION_MAJOR@.@xgboost_VERSION_MINOR@.@xgboost_VERSION_PATCH@

View File

@@ -195,12 +195,22 @@
"properties": { "properties": {
"version": { "version": {
"type": "array", "type": "array",
"const": [ "items": [
1, {
0, "type": "number",
0 "const": 1
},
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"minimum": 0
}
], ],
"additionalItems": false "minItems": 3,
"maxItems": 3
}, },
"learner": { "learner": {
"type": "object", "type": "object",

View File

@@ -0,0 +1,79 @@
'''This is a simple script that converts a pickled XGBoost
Scikit-Learn interface object from 0.90 to a native model. Pickle
format is not stable as it's a direct serialization of Python object.
We advice not to use it when stability is needed.
'''
import pickle
import json
import os
import argparse
import numpy as np
import xgboost
import warnings
def save_label_encoder(le):
'''Save the label encoder in XGBClassifier'''
meta = dict()
for k, v in le.__dict__.items():
if isinstance(v, np.ndarray):
meta[k] = v.tolist()
else:
meta[k] = v
return meta
def xgboost_skl_90to100(skl_model):
'''Extract the model and related metadata in SKL model.'''
model = {}
with open(skl_model, 'rb') as fd:
old = pickle.load(fd)
if not isinstance(old, xgboost.XGBModel):
raise TypeError(
'The script only handes Scikit-Learn interface object')
# Save Scikit-Learn specific Python attributes into a JSON document.
for k, v in old.__dict__.items():
if k == '_le':
model[k] = save_label_encoder(v)
elif k == 'classes_':
model[k] = v.tolist()
elif k == '_Booster':
continue
else:
try:
json.dumps({k: v})
model[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
booster = old.get_booster()
# Store the JSON serialization as an attribute
booster.set_attr(scikit_learn=json.dumps(model))
# Save it into a native model.
i = 0
while True:
path = 'xgboost_native_model_from_' + skl_model + '-' + str(i) + '.bin'
if os.path.exists(path):
i += 1
continue
booster.save_model(path)
break
if __name__ == '__main__':
assert xgboost.__version__ != '1.0.0', ('Please use the XGBoost version'
' that generates this pickle.')
parser = argparse.ArgumentParser(
description=('A simple script to convert pickle generated by'
' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
parser.add_argument(
'--old-pickle',
type=str,
help='Path to old pickle file of Scikit-Learn interface object. '
'Will output a native model converted from this pickle file',
required=True)
args = parser.parse_args()
xgboost_skl_90to100(args.old_pickle)

View File

@@ -91,7 +91,12 @@ Loading pickled file from different version of XGBoost
As noted, pickled model is neither portable nor stable, but in some cases the pickled As noted, pickled model is neither portable nor stable, but in some cases the pickled
models are valuable. One way to restore it in the future is to load it back with that models are valuable. One way to restore it in the future is to load it back with that
specific version of Python and XGBoost, export the model by calling `save_model`. specific version of Python and XGBoost, export the model by calling `save_model`. To help
easing the mitigation, we created a simple script for converting pickled XGBoost 0.90
Scikit-Learn interface object to XGBoost 1.0.0 native model. Please note that the script
suits simple use cases, and it's advised not to use pickle when stability is needed.
It's located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See
comments in the script for more details.
******************************************************** ********************************************************
Saving and Loading the internal parameters configuration Saving and Loading the internal parameters configuration
@@ -190,7 +195,9 @@ You can load it back to the model generated by same version of XGBoost by:
bst.load_config(config) bst.load_config(config)
This way users can study the internal representation more closely. This way users can study the internal representation more closely. Please note that some
JSON generators make use of locale dependent floating point serialization methods, which
is not supported by XGBoost.
************ ************
Future Plans Future Plans

View File

@@ -208,6 +208,8 @@ struct LearnerModelParam {
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
// this one as an immutable copy. // this one as an immutable copy.
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin); LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
bool Initialized() const { return num_feature != 0; }
}; };
} // namespace xgboost } // namespace xgboost

View File

@@ -6,6 +6,6 @@
#define XGBOOST_VER_MAJOR 1 #define XGBOOST_VER_MAJOR 1
#define XGBOOST_VER_MINOR 0 #define XGBOOST_VER_MINOR 0
#define XGBOOST_VER_PATCH 0 #define XGBOOST_VER_PATCH 1
#endif // XGBOOST_VERSION_CONFIG_H_ #endif // XGBOOST_VERSION_CONFIG_H_

View File

@@ -6,7 +6,7 @@
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId> <artifactId>xgboost-jvm_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
<packaging>pom</packaging> <packaging>pom</packaging>
<name>XGBoost JVM Package</name> <name>XGBoost JVM Package</name>
<description>JVM Package for XGBoost</description> <description>JVM Package for XGBoost</description>
@@ -37,6 +37,7 @@
<spark.version>2.4.3</spark.version> <spark.version>2.4.3</spark.version>
<scala.version>2.12.8</scala.version> <scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version> <scala.binary.version>2.12</scala.binary.version>
<hadoop.version>2.7.3</hadoop.version>
</properties> </properties>
<repositories> <repositories>
<repository> <repository>

View File

@@ -6,10 +6,10 @@
<parent> <parent>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId> <artifactId>xgboost-jvm_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</parent> </parent>
<artifactId>xgboost4j-example_2.12</artifactId> <artifactId>xgboost4j-example_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
<packaging>jar</packaging> <packaging>jar</packaging>
<build> <build>
<plugins> <plugins>
@@ -26,7 +26,7 @@
<dependency> <dependency>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId> <artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.spark</groupId> <groupId>org.apache.spark</groupId>
@@ -37,7 +37,7 @@
<dependency> <dependency>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId> <artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>

View File

@@ -6,10 +6,10 @@
<parent> <parent>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId> <artifactId>xgboost-jvm_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</parent> </parent>
<artifactId>xgboost4j-flink_2.12</artifactId> <artifactId>xgboost4j-flink_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
<build> <build>
<plugins> <plugins>
<plugin> <plugin>
@@ -26,7 +26,7 @@
<dependency> <dependency>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId> <artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>

View File

@@ -6,7 +6,7 @@
<parent> <parent>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId> <artifactId>xgboost-jvm_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</parent> </parent>
<artifactId>xgboost4j-spark_2.12</artifactId> <artifactId>xgboost4j-spark_2.12</artifactId>
<build> <build>
@@ -24,7 +24,7 @@
<dependency> <dependency>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId> <artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.spark</groupId> <groupId>org.apache.spark</groupId>

View File

@@ -1,164 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
/**
* A class which allows user to save checkpoints every a few rounds. If a previous job fails,
* the job can restart training from a saved checkpoints instead of from scratch. This class
* provides interface and helper methods for the checkpoint functionality.
*
* NOTE: This checkpoint is different from Rabit checkpoint. Rabit checkpoint is a native-level
* checkpoint stored in executor memory. This is a checkpoint which Spark driver store on HDFS
* for every a few iterations.
*
* @param sc the sparkContext object
* @param checkpointPath the hdfs path to store checkpoints
*/
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
private val logger = LogFactory.getLog("XGBoostSpark")
private val modelSuffix = ".model"
private def getPath(version: Int) = {
s"$checkpointPath/$version$modelSuffix"
}
private def getExistingVersions: Seq[Int] = {
val fs = FileSystem.get(sc.hadoopConfiguration)
if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
Seq()
} else {
fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
}
}
}
def cleanPath(): Unit = {
if (checkpointPath != "") {
FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
}
}
/**
* Load existing checkpoint with the highest version as a Booster object
*
* @return the booster with the highest version, null if no checkpoints available.
*/
private[spark] def loadCheckpointAsBooster: Booster = {
val versions = getExistingVersions
if (versions.nonEmpty) {
val version = versions.max
val fullPath = getPath(version)
val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
logger.info(s"Start training from previous booster at $fullPath")
val booster = SXGBoost.loadModel(inputStream)
booster.booster.setVersion(version)
booster
} else {
null
}
}
/**
* Clean up all previous checkpoints and save a new checkpoint
*
* @param checkpoint the checkpoint to save as an XGBoostModel
*/
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(checkpoint.getVersion)
val outputStream = fs.create(new Path(fullPath), true)
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
checkpoint.saveModel(outputStream)
prevModelPaths.foreach(path => fs.delete(path, true))
}
/**
* Clean up checkpoint boosters with version higher than or equal to the round.
*
* @param round the number of rounds in the current training job
*/
private[spark] def cleanUpHigherVersions(round: Int): Unit = {
val higherVersions = getExistingVersions.filter(_ / 2 >= round)
higherVersions.foreach { version =>
val fs = FileSystem.get(sc.hadoopConfiguration)
fs.delete(new Path(getPath(version)), true)
}
}
/**
* Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
* and total number of rounds for the training. Concretely, the checkpoint rounds start with
* prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
* reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
* and the method returns Seq(round)
*
* @param checkpointInterval Period (in iterations) between checkpoints.
* @param round the total number of rounds for the training
* @return a seq of integers, each represent the index of round to save the checkpoints
*/
private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
if (checkpointPath.nonEmpty && checkpointInterval > 0) {
val prevRounds = getExistingVersions.map(_ / 2)
val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
(firstCheckpointRound until round by checkpointInterval) :+ round
} else if (checkpointInterval <= 0) {
Seq(round)
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
}
}
}
object CheckpointManager {
case class CheckpointParam(
checkpointPath: String,
checkpointInterval: Int,
skipCleanCheckpoint: Boolean)
private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
val checkpointPath: String = params.get("checkpoint_path") match {
case None => ""
case Some(path: String) => path
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
" an instance of String.")
}
val checkpointInterval: Int = params.get("checkpoint_interval") match {
case None => 0
case Some(freq: Int) => freq
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
" an instance of Int.")
}
val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
case None => false
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
" an instance of Boolean")
}
CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
}
}

View File

@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener} import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see
private[this] case class XGBoostExecutionParams( private[this] case class XGBoostExecutionParams(
numWorkers: Int, numWorkers: Int,
round: Int, numRounds: Int,
useExternalMemory: Boolean, useExternalMemory: Boolean,
obj: ObjectiveTrait, obj: ObjectiveTrait,
eval: EvalTrait, eval: EvalTrait,
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
allowNonZeroForMissing: Boolean, allowNonZeroForMissing: Boolean,
trackerConf: TrackerConf, trackerConf: TrackerConf,
timeoutRequestWorkers: Long, timeoutRequestWorkers: Long,
checkpointParam: CheckpointParam, checkpointParam: Option[ExternalCheckpointParams],
xgbInputParams: XGBoostExecutionInputParams, xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams, earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean) { cacheTrainingSet: Boolean) {
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
.getOrElse("allow_non_zero_for_missing", false) .getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean] .asInstanceOf[Boolean]
validateSparkSslConf validateSparkSslConf
if (overridedParams.contains("tree_method")) { if (overridedParams.contains("tree_method")) {
require(overridedParams("tree_method") == "hist" || require(overridedParams("tree_method") == "hist" ||
overridedParams("tree_method") == "approx" || overridedParams("tree_method") == "approx" ||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
" an instance of Long.") " an instance of Long.")
} }
val checkpointParam = val checkpointParam =
CheckpointManager.extractParams(overridedParams) ExternalCheckpointParams.extractParams(overridedParams)
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0) val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
.asInstanceOf[Double] .asInstanceOf[Double]
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
watches: Watches, watches: Watches,
xgbExecutionParam: XGBoostExecutionParams, xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String], rabitEnv: java.util.Map[String, String],
round: Int,
obj: ObjectiveTrait, obj: ObjectiveTrait,
eval: EvalTrait, eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = { prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
// to workaround the empty partitions in training dataset, // to workaround the empty partitions in training dataset,
// this might not be the best efficient implementation, see // this might not be the best efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277) // (https://github.com/dmlc/xgboost/issues/1277)
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
rabitEnv.put("DMLC_TASK_ID", taskId) rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt) rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false") rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try { try {
Rabit.init(rabitEnv) Rabit.init(rabitEnv)
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round, val externalCheckpointParams = xgbExecutionParam.checkpointParam
val booster = if (makeCheckpoint) {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
} else {
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap, metrics, obj, eval, watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster) earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
}
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap) Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} catch { } catch {
case xgbException: XGBoostError => case xgbException: XGBoostError =>
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
trainingData: RDD[XGBLabeledPoint], trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams, xgbExecutionParams: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String], rabitEnv: java.util.Map[String, String],
checkpointRound: Int,
prevBooster: Booster, prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) { if (evalSetsMap.isEmpty) {
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
processMissingValues(labeledPoints, xgbExecutionParams.missing, processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing), xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory)) getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) xgbExecutionParams.eval, prevBooster)
}).cache() }).cache()
} else { } else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers). coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing)) xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
}, },
getCacheDirName(xgbExecutionParams.useExternalMemory)) getCacheDirName(xgbExecutionParams.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) xgbExecutionParams.eval, prevBooster)
}.cache() }.cache()
} }
} }
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
trainingData: RDD[Array[XGBLabeledPoint]], trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams, xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String], rabitEnv: java.util.Map[String, String],
checkpointRound: Int,
prevBooster: Booster, prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) { if (evalSetsMap.isEmpty) {
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing, processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing), xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory)) getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster) xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
}).cache() }).cache()
} else { } else {
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing)) xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
}, },
getCacheDirName(xgbExecutionParam.useExternalMemory)) getCacheDirName(xgbExecutionParam.useExternalMemory))
buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj, xgbExecutionParam.obj,
xgbExecutionParam.eval, xgbExecutionParam.eval,
prevBooster) prevBooster)
@@ -529,33 +534,30 @@ object XGBoost extends Serializable {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}") logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext) val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
checkpointPath)
checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet, val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
hasGroup, xgbExecParams.numWorkers) hasGroup, xgbExecParams.numWorkers)
var prevBooster = checkpointManager.loadCheckpointAsBooster val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull
try { try {
// Train for every ${savingRound} rounds and save the partially completed booster // Train for every ${savingRound} rounds and save the partially completed booster
val producedBooster = checkpointManager.getCheckpointRounds(
xgbExecParams.checkpointParam.checkpointInterval,
xgbExecParams.round).map {
checkpointRound: Int =>
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
try { val (booster, metrics) = try {
val parallelismTracker = new SparkParallelismTracker(sc, val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers, xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers) xgbExecParams.numWorkers)
val rabitEnv = tracker.getWorkerEnvs
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val boostersAndMetrics = if (hasGroup) { val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap) evalSetsMap)
} else { } else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap) prevBooster, evalSetsMap)
} }
val sparkJobThread = new Thread() { val sparkJobThread = new Thread() {
override def run() { override def run() {
@@ -569,20 +571,21 @@ object XGBoost extends Serializable {
logger.info(s"Rabit returns with exit code $trackerReturnVal") logger.info(s"Rabit returns with exit code $trackerReturnVal")
val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
boostersAndMetrics, sparkJobThread) boostersAndMetrics, sparkJobThread)
if (checkpointRound < xgbExecParams.round) {
prevBooster = booster
checkpointManager.updateCheckpoint(prevBooster)
}
(booster, metrics) (booster, metrics)
} finally { } finally {
tracker.stop() tracker.stop()
} }
}.last
// we should delete the checkpoint directory after a successful training // we should delete the checkpoint directory after a successful training
if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) { xgbExecParams.checkpointParam.foreach {
cpParam =>
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
val checkpointManager = new ExternalCheckpointManager(
cpParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanPath() checkpointManager.cleanPath()
} }
producedBooster }
(booster, metrics)
} catch { } catch {
case t: Throwable => case t: Throwable =>
// if the job was aborted due to an exception // if the job was aborted due to an exception

View File

@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables { with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
def needDeterministicRepartitioning: Boolean = { def needDeterministicRepartitioning: Boolean = {
getCheckpointPath.nonEmpty && getCheckpointInterval > 0 getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
} }
} }

View File

@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.FunSuite import org.scalatest.{FunSuite, Ignore}
import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.fs.{FileSystem, Path}
class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest { class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
private lazy val (model4, model8) = { private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {
Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
"checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
}
private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val training = buildDataFrame(Classification.train) val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", val paramMap = produceParamMap(tmpPath, 2)
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training), (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training)) new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
} }
(tmpPath, model4, model8)
}
test("test update/load models") { test("test update/load models") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString val (tmpPath, model4, model8) = createNewModels()
val manager = new CheckpointManager(sc, tmpPath) val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster)
manager.updateCheckpoint(model4._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1) assert(files.length == 1)
assert(files.head.getPath.getName == "4.model") assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4) assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
manager.updateCheckpoint(model8._booster) manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1) assert(files.length == 1)
assert(files.head.getPath.getName == "8.model") assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 8) assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
} }
test("test cleanUpHigherVersions") { test("test cleanUpHigherVersions") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString val (tmpPath, model4, model8) = createNewModels()
val manager = new CheckpointManager(sc, tmpPath)
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster) manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(round = 8) manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists()) assert(new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(round = 4) manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists()) assert(!new File(s"$tmpPath/8.model").exists())
} }
test("test checkpoint rounds") { test("test checkpoint rounds") {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString import scala.collection.JavaConverters._
val manager = new CheckpointManager(sc, tmpPath) val (tmpPath, model4, model8) = createNewModels()
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7)) val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7)) assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster) manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7)) assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
} }
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val testDM = new DMatrix(Classification.test.iterator) val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
val paramMap = produceParamMap(tmpPath, 2)
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map() val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
val skipCleanCheckpointMap = val skipCleanCheckpointMap =
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map() if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
val paramMap = Map("eta" -> "1", "max_depth" -> 2,
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
skipCleanCheckpointMap
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training) val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM) val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
if (skipCleanCheckpoint) { if (skipCleanCheckpoint) {
// Check only one model is kept after training // Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model") val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model // Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster)) assert(error(tmpModel) >= error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster)) assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1) assert(error(nextModel._booster) < 0.1)
} else { } else {

View File

@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
" stop the application") { " stop the application") {
val spark = ss val spark = ss
import spark.implicits._ import spark.implicits._
ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector, // vector,
val testDF = Seq( val testDF = Seq(
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
"does not stop application") { "does not stop application") {
val spark = ss val spark = ss
import spark.implicits._ import spark.implicits._
ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector, // vector,
val testDF = Seq( val testDF = Seq(

View File

@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.ParamMap

View File

@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkConf, SparkContext} import org.scalatest.{FunSuite, Ignore}
import org.scalatest.FunSuite
class RabitRobustnessSuite extends FunSuite with PerTest { class RabitRobustnessSuite extends FunSuite with PerTest {

View File

@@ -6,13 +6,25 @@
<parent> <parent>
<groupId>ml.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId> <artifactId>xgboost-jvm_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
</parent> </parent>
<artifactId>xgboost4j_2.12</artifactId> <artifactId>xgboost4j_2.12</artifactId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0</version>
<packaging>jar</packaging> <packaging>jar</packaging>
<dependencies> <dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>

View File

@@ -0,0 +1,117 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
public class ExternalCheckpointManager {
private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private FileSystem fs;
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
if (checkpointPath == null || checkpointPath.isEmpty()) {
throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
" empty checkpoint path");
}
this.checkpointPath = new Path(checkpointPath);
this.fs = fs;
}
private String getPath(int version) {
return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
}
private List<Integer> getExistingVersions() throws IOException {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
.map(fileName -> Integer.valueOf(
fileName.substring(0, fileName.length() - modelSuffix.length())))
.collect(Collectors.toList());
}
}
public void cleanPath() throws IOException {
fs.delete(checkpointPath, true);
}
public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
}
}
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
} catch (IOException e) {
logger.error("failed to delete outdated checkpoint at " + path, e);
}
});
}
}
public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
logger.error("failed to clean checkpoint from other training instance", e);
}
});
}
public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
throws IOException {
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}
}
}

View File

@@ -15,12 +15,16 @@
*/ */
package ml.dmlc.xgboost4j.java; package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream;
import java.util.*; import java.util.*;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
/** /**
* trainer for xgboost * trainer for xgboost
@@ -108,35 +112,34 @@ public class XGBoost {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null); return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
} }
/** private static void saveCheckpoint(
* Train a booster given parameters. Booster booster,
* int iter,
* @param dtrain Data to be trained. Set<Integer> checkpointIterations,
* @param params Parameters. ExternalCheckpointManager ecm) throws XGBoostError {
* @param round Number of boosting iterations. try {
* @param watches a group of items to be evaluated during training, this allows user to watch if (checkpointIterations.contains(iter)) {
* performance on the validation set. ecm.updateCheckpoint(booster);
* @param metrics array containing the evaluation metrics for each matrix in watches for each }
* iteration } catch (Exception e) {
* @param earlyStoppingRounds if non-zero, training would be stopped logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
* after a specified number of consecutive throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
* goes to the unexpected direction in any evaluation metric. }
* @param obj customized objective }
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null. public static Booster trainAndSaveCheckpoint(
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain, DMatrix dtrain,
Map<String, Object> params, Map<String, Object> params,
int round, int numRounds,
Map<String, DMatrix> watches, Map<String, DMatrix> watches,
float[][] metrics, float[][] metrics,
IObjective obj, IObjective obj,
IEvaluation eval, IEvaluation eval,
int earlyStoppingRounds, int earlyStoppingRounds,
Booster booster) throws XGBoostError { Booster booster,
int checkpointInterval,
String checkpointPath,
FileSystem fs) throws XGBoostError, IOException {
//collect eval matrixs //collect eval matrixs
String[] evalNames; String[] evalNames;
DMatrix[] evalMats; DMatrix[] evalMats;
@@ -144,6 +147,11 @@ public class XGBoost {
int bestIteration; int bestIteration;
List<String> names = new ArrayList<String>(); List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>(); List<DMatrix> mats = new ArrayList<DMatrix>();
Set<Integer> checkpointIterations = new HashSet<>();
ExternalCheckpointManager ecm = null;
if (checkpointPath != null) {
ecm = new ExternalCheckpointManager(checkpointPath, fs);
}
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) { for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey()); names.add(evalEntry.getKey());
@@ -158,7 +166,7 @@ public class XGBoost {
bestScore = Float.MAX_VALUE; bestScore = Float.MAX_VALUE;
} }
bestIteration = 0; bestIteration = 0;
metrics = metrics == null ? new float[evalNames.length][round] : metrics; metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
//collect all data matrixs //collect all data matrixs
DMatrix[] allMats; DMatrix[] allMats;
@@ -181,14 +189,19 @@ public class XGBoost {
booster.setParams(params); booster.setParams(params);
} }
//begin to train if (ecm != null) {
for (int iter = booster.getVersion() / 2; iter < round; iter++) { checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
}
// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) { if (booster.getVersion() % 2 == 0) {
if (obj != null) { if (obj != null) {
booster.update(dtrain, obj); booster.update(dtrain, obj);
} else { } else {
booster.update(dtrain, iter); booster.update(dtrain, iter);
} }
saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint(); booster.saveRabitCheckpoint();
} }
@@ -239,6 +252,44 @@ public class XGBoost {
return booster; return booster;
} }
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRounds if non-zero, training would be stopped
* after a specified number of consecutive
* goes to the unexpected direction in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRounds,
Booster booster) throws XGBoostError {
try {
return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
earlyStoppingRounds, booster,
-1, null, null);
} catch (IOException e) {
logger.error("training failed in xgboost4j", e);
throw new XGBoostError("training failed in xgboost4j ", e);
}
}
private static Integer tryGetIntFromObject(Object o) { private static Integer tryGetIntFromObject(Object o) {
if (o instanceof Integer) { if (o instanceof Integer) {
return (int)o; return (int)o;

View File

@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
public XGBoostError(String message) { public XGBoostError(String message) {
super(message); super(message);
} }
public XGBoostError(String message, Throwable cause) {
super(message, cause);
}
} }

View File

@@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
import org.apache.hadoop.fs.FileSystem
class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
extends JavaECM(checkpointPath, fs) {
def updateCheckpoint(booster: Booster): Unit = {
super.updateCheckpoint(booster.booster)
}
def loadCheckpointAsScalaBooster(): Booster = {
val loadedBooster = super.loadCheckpointAsBooster()
if (loadedBooster == null) {
null
} else {
new Booster(loadedBooster)
}
}
}

View File

@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream import java.io.InputStream
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError} import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
/** /**
* XGBoost Scala Training function. * XGBoost Scala Training function.
*/ */
object XGBoost { object XGBoost {
private[scala] def trainAndSaveCheckpoint(
dtrain: DMatrix,
params: Map[String, Any],
numRounds: Int,
watches: Map[String, DMatrix] = Map(),
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
earlyStoppingRound: Int = 0,
prevBooster: Booster,
checkpointParams: Option[ExternalCheckpointParams]): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (prevBooster == null) {
null
} else {
prevBooster.booster
}
val xgboostInJava = checkpointParams.
map(cp => {
JXGBoost.trainAndSaveCheckpoint(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
cp.checkpointInterval,
cp.checkpointPath,
new Path(cp.checkpointPath).getFileSystem(new Configuration()))
}).
getOrElse(
JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
)
if (prevBooster == null) {
new Booster(xgboostInJava)
} else {
// Avoid creating a new SBooster with the same JBooster
prevBooster
}
}
/** /**
* Train a booster given parameters. * Train a booster given parameters.
* *
@@ -55,23 +101,8 @@ object XGBoost {
eval: EvalTrait = null, eval: EvalTrait = null,
earlyStoppingRound: Int = 0, earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = { booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
val jBooster = if (booster == null) { booster, None)
null
} else {
booster.booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
if (booster == null) {
new Booster(xgboostInJava)
} else {
// Avoid creating a new SBooster with the same JBooster
booster
}
} }
/** /**
@@ -126,3 +157,41 @@ object XGBoost {
new Booster(xgboostInJava) new Booster(xgboostInJava)
} }
} }
private[scala] case class ExternalCheckpointParams(
checkpointInterval: Int,
checkpointPath: String,
skipCleanCheckpoint: Boolean)
private[scala] object ExternalCheckpointParams {
def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
val checkpointPath: String = params.get("checkpoint_path") match {
case None | Some(null) | Some("") => null
case Some(path: String) => path
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
s" an instance of String, but current value is ${params("checkpoint_path")}")
}
val checkpointInterval: Int = params.get("checkpoint_interval") match {
case None => 0
case Some(freq: Int) => freq
case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
" an instance of Int.")
}
val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
case None => false
case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
" an instance of Boolean")
}
if (checkpointPath == null || checkpointInterval == 0) {
None
} else {
Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
}
}
}

View File

@@ -62,6 +62,7 @@ setup(name='xgboost',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7'], 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8'],
python_requires='>=3.5', python_requires='>=3.5',
url='https://github.com/dmlc/xgboost') url='https://github.com/dmlc/xgboost')

View File

@@ -79,6 +79,7 @@ setup(name='xgboost',
'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7'], 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8'],
python_requires='>=3.5', python_requires='>=3.5',
url='https://github.com/dmlc/xgboost') url='https://github.com/dmlc/xgboost')

View File

@@ -1 +1 @@
1.0.0-SNAPSHOT 1.0.1

View File

@@ -5,6 +5,8 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
""" """
import os import os
import sys
import warnings
from .core import DMatrix, Booster from .core import DMatrix, Booster
from .training import train, cv from .training import train, cv
@@ -19,6 +21,12 @@ try:
except ImportError: except ImportError:
pass pass
if sys.version_info[:2] == (3, 5):
warnings.warn(
'Python 3.5 support is deprecated; XGBoost will require Python 3.6+ in the near future. ' +
'Consider upgrading to Python 3.6+.',
FutureWarning)
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION') VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
with open(VERSION_FILE) as f: with open(VERSION_FILE) as f:
__version__ = f.read().strip() __version__ = f.read().strip()

View File

@@ -600,6 +600,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
results = train(self.client, params, dtrain, results = train(self.client, params, dtrain,
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
evals=evals) evals=evals)
# pylint: disable=attribute-defined-outside-init
self._Booster = results['booster'] self._Booster = results['booster']
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history'] self.evals_result_ = results['history']

View File

@@ -200,7 +200,7 @@ Parameters
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""", @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model', 'objective']) ['estimators', 'model', 'objective'])
class XGBModel(XGBModelBase): class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring # pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100, def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
verbosity=None, objective=None, booster=None, verbosity=None, objective=None, booster=None,
tree_method=None, n_jobs=None, gamma=None, tree_method=None, n_jobs=None, gamma=None,
@@ -210,7 +210,8 @@ class XGBModel(XGBModelBase):
scale_pos_weight=None, base_score=None, random_state=None, scale_pos_weight=None, base_score=None, random_state=None,
missing=None, num_parallel_tree=None, missing=None, num_parallel_tree=None,
monotone_constraints=None, interaction_constraints=None, monotone_constraints=None, interaction_constraints=None,
importance_type="gain", gpu_id=None, **kwargs): importance_type="gain", gpu_id=None,
validate_parameters=False, **kwargs):
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise XGBoostError( raise XGBoostError(
'sklearn needs to be installed in order to use this module') 'sklearn needs to be installed in order to use this module')
@@ -243,6 +244,10 @@ class XGBModel(XGBModelBase):
self.interaction_constraints = interaction_constraints self.interaction_constraints = interaction_constraints
self.importance_type = importance_type self.importance_type = importance_type
self.gpu_id = gpu_id self.gpu_id = gpu_id
# Parameter validation is not working with Scikit-Learn interface, as
# it passes all paraemters into XGBoost core, whether they are used or
# not.
self.validate_parameters = validate_parameters
def __setstate__(self, state): def __setstate__(self, state):
# backward compatibility code # backward compatibility code
@@ -314,11 +319,35 @@ class XGBModel(XGBModelBase):
if isinstance(params['random_state'], np.random.RandomState): if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint( params['random_state'] = params['random_state'].randint(
np.iinfo(np.int32).max) np.iinfo(np.int32).max)
# Parameter validation is not working with Scikit-Learn interface, as
# it passes all paraemters into XGBoost core, whether they are used or def parse_parameter(value):
# not. for t in (int, float):
if 'validate_parameters' not in params.keys(): try:
params['validate_parameters'] = False ret = t(value)
return ret
except ValueError:
continue
return None
# Get internal parameter values
try:
config = json.loads(self.get_booster().save_config())
stack = [config]
internal = {}
while stack:
obj = stack.pop()
for k, v in obj.items():
if k.endswith('_param'):
for p_k, p_v in v.items():
internal[p_k] = p_v
elif isinstance(v, dict):
stack.append(v)
for k, v in internal.items():
if k in params.keys() and params[k] is None:
params[k] = parse_parameter(v)
except XGBoostError:
pass
return params return params
def get_xgb_params(self): def get_xgb_params(self):
@@ -405,8 +434,8 @@ class XGBModel(XGBModelBase):
self.classes_ = np.array(v) self.classes_ = np.array(v)
continue continue
if k == 'type' and type(self).__name__ != v: if k == 'type' and type(self).__name__ != v:
msg = f'Current model type: {type(self).__name__}, ' + \ msg = 'Current model type: {}, '.format(type(self).__name__) + \
f'type of model in file: {v}' 'type of model in file: {}'.format(v)
raise TypeError(msg) raise TypeError(msg)
if k == 'type': if k == 'type':
continue continue

View File

@@ -2,6 +2,7 @@
* Copyright (c) by Contributors 2019 * Copyright (c) by Contributors 2019
*/ */
#include <cctype> #include <cctype>
#include <locale>
#include <sstream> #include <sstream>
#include <limits> #include <limits>
#include <cmath> #include <cmath>
@@ -24,7 +25,7 @@ void JsonWriter::Visit(JsonArray const* arr) {
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
auto const& value = vec[i]; auto const& value = vec[i];
this->Save(value); this->Save(value);
if (i != size-1) { Write(", "); } if (i != size-1) { Write(","); }
} }
this->Write("]"); this->Write("]");
} }
@@ -38,7 +39,7 @@ void JsonWriter::Visit(JsonObject const* obj) {
size_t size = obj->getObject().size(); size_t size = obj->getObject().size();
for (auto& value : obj->getObject()) { for (auto& value : obj->getObject()) {
this->Write("\"" + value.first + "\": "); this->Write("\"" + value.first + "\":");
this->Save(value.second); this->Save(value.second);
if (i != size-1) { if (i != size-1) {
@@ -692,47 +693,23 @@ Json JsonReader::ParseBoolean() {
return Json{JsonBoolean{result}}; return Json{JsonBoolean{result}};
} }
// This is an ad-hoc solution for writing numeric value in standard way. We need to add
// a locale independent way of writing stream like `std::{from, to}_chars' from C++-17.
// FIXME(trivialfis): Remove this.
class GlobalCLocale {
std::locale ori_;
public:
GlobalCLocale() : ori_{std::locale()} {
std::string const name {"C"};
try {
std::locale::global(std::locale(name.c_str()));
} catch (std::runtime_error const& e) {
LOG(FATAL) << "Failed to set locale: " << name;
}
}
~GlobalCLocale() {
std::locale::global(ori_);
}
};
Json Json::Load(StringView str) { Json Json::Load(StringView str) {
GlobalCLocale guard;
JsonReader reader(str); JsonReader reader(str);
Json json{reader.Load()}; Json json{reader.Load()};
return json; return json;
} }
Json Json::Load(JsonReader* reader) { Json Json::Load(JsonReader* reader) {
GlobalCLocale guard;
Json json{reader->Load()}; Json json{reader->Load()};
return json; return json;
} }
void Json::Dump(Json json, std::ostream *stream, bool pretty) { void Json::Dump(Json json, std::ostream *stream, bool pretty) {
GlobalCLocale guard;
JsonWriter writer(stream, pretty); JsonWriter writer(stream, pretty);
writer.Save(json); writer.Save(json);
} }
void Json::Dump(Json json, std::string* str, bool pretty) { void Json::Dump(Json json, std::string* str, bool pretty) {
GlobalCLocale guard;
std::stringstream ss; std::stringstream ss;
JsonWriter writer(&ss, pretty); JsonWriter writer(&ss, pretty);
writer.Save(json); writer.Save(json);

View File

@@ -15,12 +15,23 @@
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#if defined(XGBOOST_STRICT_R_MODE)
#define OBSERVER_PRINT LOG(INFO)
#define OBSERVER_ENDL ""
#define OBSERVER_NEWLINE ""
#else
#define OBSERVER_PRINT std::cout
#define OBSERVER_ENDL std::endl
#define OBSERVER_NEWLINE "\n"
#endif // defined(XGBOOST_STRICT_R_MODE)
namespace xgboost { namespace xgboost {
/*\brief An observer for logging internal data structures. /*\brief An observer for logging internal data structures.
* *
* This class is designed to be `diff` tool friendly, which means it uses plain * This class is designed to be `diff` tool friendly, which means it uses plain
* `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or * `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or
* similiar facilities. * similiar facilities. Exception: use `LOG(INFO)` for the R package, to comply
* with CRAN policy.
*/ */
class TrainingObserver { class TrainingObserver {
#if defined(XGBOOST_USE_DEBUG_OUTPUT) #if defined(XGBOOST_USE_DEBUG_OUTPUT)
@@ -32,17 +43,17 @@ class TrainingObserver {
public: public:
void Update(int32_t iter) const { void Update(int32_t iter) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; } if (XGBOOST_EXPECT(!observe_, true)) { return; }
std::cout << "Iter: " << iter << std::endl; OBSERVER_PRINT << "Iter: " << iter << OBSERVER_ENDL;
} }
/*\brief Observe tree. */ /*\brief Observe tree. */
void Observe(RegTree const& tree) { void Observe(RegTree const& tree) {
if (XGBOOST_EXPECT(!observe_, true)) { return; } if (XGBOOST_EXPECT(!observe_, true)) { return; }
std::cout << "Tree:" << std::endl; OBSERVER_PRINT << "Tree:" << OBSERVER_ENDL;
Json j_tree {Object()}; Json j_tree {Object()};
tree.SaveModel(&j_tree); tree.SaveModel(&j_tree);
std::string str; std::string str;
Json::Dump(j_tree, &str, true); Json::Dump(j_tree, &str, true);
std::cout << str << std::endl; OBSERVER_PRINT << str << OBSERVER_ENDL;
} }
/*\brief Observe tree. */ /*\brief Observe tree. */
void Observe(RegTree const* p_tree) { void Observe(RegTree const* p_tree) {
@@ -54,15 +65,15 @@ class TrainingObserver {
template <typename T> template <typename T>
void Observe(std::vector<T> const& h_vec, std::string name) const { void Observe(std::vector<T> const& h_vec, std::string name) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; } if (XGBOOST_EXPECT(!observe_, true)) { return; }
std::cout << "Procedure: " << name << std::endl; OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
for (size_t i = 0; i < h_vec.size(); ++i) { for (size_t i = 0; i < h_vec.size(); ++i) {
std::cout << h_vec[i] << ", "; OBSERVER_PRINT << h_vec[i] << ", ";
if (i % 8 == 0) { if (i % 8 == 0) {
std::cout << '\n'; OBSERVER_PRINT << OBSERVER_NEWLINE;
} }
} }
std::cout << std::endl; OBSERVER_PRINT << OBSERVER_ENDL;
} }
/*\brief Observe data hosted by `HostDeviceVector'. */ /*\brief Observe data hosted by `HostDeviceVector'. */
template <typename T> template <typename T>
@@ -85,16 +96,16 @@ class TrainingObserver {
if (XGBOOST_EXPECT(!observe_, true)) { return; } if (XGBOOST_EXPECT(!observe_, true)) { return; }
Json obj {toJson(p)}; Json obj {toJson(p)};
std::cout << "Parameter: " << name << ":\n" << obj << std::endl; OBSERVER_PRINT << "Parameter: " << name << ":\n" << obj << OBSERVER_ENDL;
} }
/*\brief Observe parameters provided by users. */ /*\brief Observe parameters provided by users. */
void Observe(Args const& args) const { void Observe(Args const& args) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; } if (XGBOOST_EXPECT(!observe_, true)) { return; }
for (auto kv : args) { for (auto kv : args) {
std::cout << kv.first << ": " << kv.second << "\n"; OBSERVER_PRINT << kv.first << ": " << kv.second << OBSERVER_NEWLINE;
} }
std::cout << std::endl; OBSERVER_PRINT << OBSERVER_ENDL;
} }
/*\brief Get a global instance. */ /*\brief Get a global instance. */

View File

@@ -89,7 +89,7 @@ void Monitor::PrintStatistics(StatMap const& statistics) const {
"Timer for " << kv.first << " did not get stopped properly."; "Timer for " << kv.first << " did not get stopped properly.";
continue; continue;
} }
std::cout << kv.first << ": " << static_cast<double>(kv.second.second) / 1e+6 LOG(CONSOLE) << kv.first << ": " << static_cast<double>(kv.second.second) / 1e+6
<< "s, " << kv.second.first << " calls @ " << "s, " << kv.second.first << " calls @ "
<< kv.second.second << kv.second.second
<< "us" << std::endl; << "us" << std::endl;
@@ -107,10 +107,9 @@ void Monitor::Print() const {
if (rabit::GetRank() == 0) { if (rabit::GetRank() == 0) {
LOG(CONSOLE) << "======== Monitor: " << label << " ========"; LOG(CONSOLE) << "======== Monitor: " << label << " ========";
for (size_t i = 0; i < world.size(); ++i) { for (size_t i = 0; i < world.size(); ++i) {
std::cout << "From rank: " << i << ": " << std::endl; LOG(CONSOLE) << "From rank: " << i << ": " << std::endl;
auto const& statistic = world[i]; auto const& statistic = world[i];
this->PrintStatistics(statistic); this->PrintStatistics(statistic);
std::cout << std::endl;
} }
} }
} else { } else {
@@ -123,7 +122,6 @@ void Monitor::Print() const {
LOG(CONSOLE) << "======== Monitor: " << label << " ========"; LOG(CONSOLE) << "======== Monitor: " << label << " ========";
this->PrintStatistics(stat_map); this->PrintStatistics(stat_map);
} }
std::cout << std::endl;
} }
} // namespace common } // namespace common

View File

@@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2019 by Contributors * Copyright 2014-2020 by Contributors
* \file learner.cc * \file learner.cc
* \brief Implementation of learning algorithm. * \brief Implementation of learning algorithm.
* \author Tianqi Chen * \author Tianqi Chen
@@ -67,19 +67,26 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
/* \brief global bias */ /* \brief global bias */
bst_float base_score; bst_float base_score;
/* \brief number of features */ /* \brief number of features */
unsigned num_feature; uint32_t num_feature;
/* \brief number of classes, if it is multi-class classification */ /* \brief number of classes, if it is multi-class classification */
int num_class; int32_t num_class;
/*! \brief Model contain additional properties */ /*! \brief Model contain additional properties */
int contain_extra_attrs; int32_t contain_extra_attrs;
/*! \brief Model contain eval metrics */ /*! \brief Model contain eval metrics */
int contain_eval_metrics; int32_t contain_eval_metrics;
/*! \brief the version of XGBoost. */
uint32_t major_version;
uint32_t minor_version;
/*! \brief reserved field */ /*! \brief reserved field */
int reserved[29]; int reserved[27];
/*! \brief constructor */ /*! \brief constructor */
LearnerModelParamLegacy() { LearnerModelParamLegacy() {
std::memset(this, 0, sizeof(LearnerModelParamLegacy)); std::memset(this, 0, sizeof(LearnerModelParamLegacy));
base_score = 0.5f; base_score = 0.5f;
major_version = std::get<0>(Version::Self());
minor_version = std::get<1>(Version::Self());
static_assert(sizeof(LearnerModelParamLegacy) == 136,
"Do not change the size of this struct, as it will break binary IO.");
} }
// Skip other legacy fields. // Skip other legacy fields.
Json ToJson() const { Json ToJson() const {
@@ -118,7 +125,8 @@ LearnerModelParam::LearnerModelParam(
: base_score{base_margin}, num_feature{user_param.num_feature}, : base_score{base_margin}, num_feature{user_param.num_feature},
num_output_group{user_param.num_class == 0 num_output_group{user_param.num_class == 0
? 1 ? 1
: static_cast<uint32_t>(user_param.num_class)} {} : static_cast<uint32_t>(user_param.num_class)}
{}
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> { struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none. // data split mode, can be row, col, or none.
@@ -140,7 +148,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.describe("Data split mode for distributed training."); .describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric) DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(0) .set_default(0)
.describe("flag to disable default metric. Set to >0 to disable"); .describe("Flag to disable default metric. Set to >0 to disable");
DMLC_DECLARE_FIELD(booster) DMLC_DECLARE_FIELD(booster)
.set_default("gbtree") .set_default("gbtree")
.describe("Gradient booster used for training."); .describe("Gradient booster used for training.");
@@ -200,6 +208,7 @@ class LearnerImpl : public Learner {
Args args = {cfg_.cbegin(), cfg_.cend()}; Args args = {cfg_.cbegin(), cfg_.cend()};
tparam_.UpdateAllowUnknown(args); tparam_.UpdateAllowUnknown(args);
auto mparam_backup = mparam_;
mparam_.UpdateAllowUnknown(args); mparam_.UpdateAllowUnknown(args);
generic_parameters_.UpdateAllowUnknown(args); generic_parameters_.UpdateAllowUnknown(args);
generic_parameters_.CheckDeprecated(); generic_parameters_.CheckDeprecated();
@@ -217,17 +226,33 @@ class LearnerImpl : public Learner {
// set seed only before the model is initialized // set seed only before the model is initialized
common::GlobalRandom().seed(generic_parameters_.seed); common::GlobalRandom().seed(generic_parameters_.seed);
// must precede configure gbm since num_features is required for gbm // must precede configure gbm since num_features is required for gbm
this->ConfigureNumFeatures(); this->ConfigureNumFeatures();
args = {cfg_.cbegin(), cfg_.cend()}; // renew args = {cfg_.cbegin(), cfg_.cend()}; // renew
this->ConfigureObjective(old_tparam, &args); this->ConfigureObjective(old_tparam, &args);
this->ConfigureGBM(old_tparam, args);
this->ConfigureMetrics(args);
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
// Before 1.0.0, we save `base_score` into binary as a transformed value by objective.
// After 1.0.0 we save the value provided by user and keep it immutable instead. To
// keep the stability, we initialize it in binary LoadModel instead of configuration.
// Under what condition should we omit the transformation:
//
// - base_score is loaded from old binary model.
//
// What are the other possible conditions:
//
// - model loaded from new binary or JSON.
// - model is created from scratch.
// - model is configured second time due to change of parameter
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
learner_model_param_ = LearnerModelParam(mparam_, learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score)); obj_->ProbToMargin(mparam_.base_score));
}
this->ConfigureGBM(old_tparam, args);
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
this->ConfigureMetrics(args);
this->need_configuration_ = false; this->need_configuration_ = false;
if (generic_parameters_.validate_parameters) { if (generic_parameters_.validate_parameters) {
@@ -269,10 +294,7 @@ class LearnerImpl : public Learner {
} }
} }
} }
auto learner_model_param = mparam_.ToJson();
for (auto const& kv : get<Object>(learner_model_param)) {
keys.emplace_back(kv.first);
}
keys.emplace_back(kEvalMetric); keys.emplace_back(kEvalMetric);
keys.emplace_back("verbosity"); keys.emplace_back("verbosity");
keys.emplace_back("num_output_group"); keys.emplace_back("num_output_group");
@@ -340,9 +362,6 @@ class LearnerImpl : public Learner {
cache_)); cache_));
gbm_->LoadModel(gradient_booster); gbm_->LoadModel(gradient_booster);
learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score));
auto const& j_attributes = get<Object const>(learner.at("attributes")); auto const& j_attributes = get<Object const>(learner.at("attributes"));
attributes_.clear(); attributes_.clear();
for (auto const& kv : j_attributes) { for (auto const& kv : j_attributes) {
@@ -425,6 +444,7 @@ class LearnerImpl : public Learner {
auto& learner_parameters = out["learner"]; auto& learner_parameters = out["learner"];
learner_parameters["learner_train_param"] = toJson(tparam_); learner_parameters["learner_train_param"] = toJson(tparam_);
learner_parameters["learner_model_param"] = mparam_.ToJson();
learner_parameters["gradient_booster"] = Object(); learner_parameters["gradient_booster"] = Object();
auto& gradient_booster = learner_parameters["gradient_booster"]; auto& gradient_booster = learner_parameters["gradient_booster"];
gbm_->SaveConfig(&gradient_booster); gbm_->SaveConfig(&gradient_booster);
@@ -461,6 +481,7 @@ class LearnerImpl : public Learner {
} }
if (header[0] == '{') { if (header[0] == '{') {
// Dispatch to JSON
auto json_stream = common::FixedSizeStream(&fp); auto json_stream = common::FixedSizeStream(&fp);
std::string buffer; std::string buffer;
json_stream.Take(&buffer); json_stream.Take(&buffer);
@@ -473,25 +494,10 @@ class LearnerImpl : public Learner {
// read parameter // read parameter
CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_)) CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
<< "BoostLearner: wrong model format"; << "BoostLearner: wrong model format";
{
// backward compatibility code for compatible with old model type CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
// for new model, Read(&name_obj_) is suffice
uint64_t len;
CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len));
if (len >= std::numeric_limits<unsigned>::max()) {
int gap;
CHECK_EQ(fi->Read(&gap, sizeof(gap)), sizeof(gap))
<< "BoostLearner: wrong model format";
len = len >> static_cast<uint64_t>(32UL);
}
if (len != 0) {
tparam_.objective.resize(len);
CHECK_EQ(fi->Read(&tparam_.objective[0], len), len)
<< "BoostLearner: wrong model format";
}
}
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format"; CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
// duplicated code with LazyInitModel
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
&learner_model_param_, cache_)); &learner_model_param_, cache_));
@@ -510,34 +516,57 @@ class LearnerImpl : public Learner {
} }
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end()); attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
} }
if (tparam_.objective == "count:poisson") { bool warn_old_model { false };
std::string max_delta_step; if (attributes_.find("count_poisson_max_delta_step") != attributes_.cend()) {
fi->Read(&max_delta_step); // Loading model from < 1.0.0, objective is not saved.
cfg_["max_delta_step"] = max_delta_step; cfg_["max_delta_step"] = attributes_.at("count_poisson_max_delta_step");
attributes_.erase("count_poisson_max_delta_step");
warn_old_model = true;
} else {
warn_old_model = false;
} }
if (mparam_.contain_eval_metrics != 0) {
std::vector<std::string> metr; if (mparam_.major_version >= 1) {
fi->Read(&metr); learner_model_param_ = LearnerModelParam(mparam_,
for (auto name : metr) { obj_->ProbToMargin(mparam_.base_score));
metrics_.emplace_back(Metric::Create(name, &generic_parameters_)); } else {
// Before 1.0.0, base_score is saved as a transformed value, and there's no version
// attribute in the saved model.
learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score);
warn_old_model = true;
}
if (attributes_.find("objective") != attributes_.cend()) {
auto obj_str = attributes_.at("objective");
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
obj_->LoadConfig(j_obj);
attributes_.erase("objective");
} else {
warn_old_model = true;
}
if (attributes_.find("metrics") != attributes_.cend()) {
auto metrics_str = attributes_.at("metrics");
std::vector<std::string> names { common::Split(metrics_str, ';') };
attributes_.erase("metrics");
for (auto const& n : names) {
this->SetParam(kEvalMetric, n);
} }
} }
if (warn_old_model) {
LOG(WARNING) << "Loading model from XGBoost < 1.0.0, consider saving it "
"again for improved compatibility";
}
// Renew the version.
mparam_.major_version = std::get<0>(Version::Self());
mparam_.minor_version = std::get<1>(Version::Self());
cfg_["num_class"] = common::ToString(mparam_.num_class); cfg_["num_class"] = common::ToString(mparam_.num_class);
cfg_["num_feature"] = common::ToString(mparam_.num_feature); cfg_["num_feature"] = common::ToString(mparam_.num_feature);
auto n = tparam_.__DICT__(); auto n = tparam_.__DICT__();
cfg_.insert(n.cbegin(), n.cend()); cfg_.insert(n.cbegin(), n.cend());
Args args = {cfg_.cbegin(), cfg_.cend()};
generic_parameters_.UpdateAllowUnknown(args);
gbm_->Configure(args);
obj_->Configure({cfg_.begin(), cfg_.end()});
for (auto& p_metric : metrics_) {
p_metric->Configure({cfg_.begin(), cfg_.end()});
}
// copy dsplit from config since it will not run again during restore // copy dsplit from config since it will not run again during restore
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) { if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow; tparam_.dsplit = DataSplitMode::kRow;
@@ -554,15 +583,8 @@ class LearnerImpl : public Learner {
void SaveModel(dmlc::Stream* fo) const override { void SaveModel(dmlc::Stream* fo) const override {
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
std::vector<std::pair<std::string, std::string> > extra_attr; std::vector<std::pair<std::string, std::string> > extra_attr;
// extra attributed to be added just before saving
if (tparam_.objective == "count:poisson") {
auto it = cfg_.find("max_delta_step");
if (it != cfg_.end()) {
// write `max_delta_step` parameter as extra attribute of booster
mparam.contain_extra_attrs = 1; mparam.contain_extra_attrs = 1;
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
}
}
{ {
std::vector<std::string> saved_params; std::vector<std::string> saved_params;
// check if rabit_bootstrap_cache were set to non zero before adding to checkpoint // check if rabit_bootstrap_cache were set to non zero before adding to checkpoint
@@ -579,6 +601,24 @@ class LearnerImpl : public Learner {
} }
} }
} }
{
// Similar to JSON model IO, we save the objective.
Json j_obj { Object() };
obj_->SaveConfig(&j_obj);
std::string obj_doc;
Json::Dump(j_obj, &obj_doc);
extra_attr.emplace_back("objective", obj_doc);
}
// As of 1.0.0, JVM Package and R Package uses Save/Load model for serialization.
// Remove this part once they are ported to use actual serialization methods.
if (mparam.contain_eval_metrics != 0) {
std::stringstream os;
for (auto& ev : metrics_) {
os << ev->Name() << ";";
}
extra_attr.emplace_back("metrics", os.str());
}
fo->Write(&mparam, sizeof(LearnerModelParamLegacy)); fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
fo->Write(tparam_.objective); fo->Write(tparam_.objective);
fo->Write(tparam_.booster); fo->Write(tparam_.booster);
@@ -591,25 +631,6 @@ class LearnerImpl : public Learner {
fo->Write(std::vector<std::pair<std::string, std::string>>( fo->Write(std::vector<std::pair<std::string, std::string>>(
attr.begin(), attr.end())); attr.begin(), attr.end()));
} }
if (tparam_.objective == "count:poisson") {
auto it = cfg_.find("max_delta_step");
if (it != cfg_.end()) {
fo->Write(it->second);
} else {
// recover value of max_delta_step from extra attributes
auto it2 = attributes_.find("count_poisson_max_delta_step");
const std::string max_delta_step
= (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue;
fo->Write(max_delta_step);
}
}
if (mparam.contain_eval_metrics != 0) {
std::vector<std::string> metr;
for (auto& ev : metrics_) {
metr.emplace_back(ev->Name());
}
fo->Write(metr);
}
} }
void Save(dmlc::Stream* fo) const override { void Save(dmlc::Stream* fo) const override {
@@ -663,11 +684,13 @@ class LearnerImpl : public Learner {
If you are loading a serialized model (like pickle in Python) generated by older If you are loading a serialized model (like pickle in Python) generated by older
XGBoost, please export the model by calling `Booster.save_model` from that version XGBoost, please export the model by calling `Booster.save_model` from that version
first, then load it back in current version. See: first, then load it back in current version. There's a simple script for helping
the process. See:
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
for more details about differences between saving model and serializing. for reference to the script, and more details about differences between saving model and
serializing.
)doc"; )doc";
int64_t sz {-1}; int64_t sz {-1};
@@ -856,7 +879,8 @@ class LearnerImpl : public Learner {
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) { void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
// Once binary IO is gone, NONE of these config is useful. // Once binary IO is gone, NONE of these config is useful.
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0") { if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" &&
tparam_.objective != "multi:softprob") {
cfg_["num_output_group"] = cfg_["num_class"]; cfg_["num_output_group"] = cfg_["num_class"];
if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) { if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
tparam_.objective = "multi:softmax"; tparam_.objective = "multi:softmax";
@@ -921,7 +945,6 @@ class LearnerImpl : public Learner {
} }
CHECK_NE(mparam_.num_feature, 0) CHECK_NE(mparam_.num_feature, 0)
<< "0 feature is supplied. Are you using raw Booster interface?"; << "0 feature is supplied. Are you using raw Booster interface?";
learner_model_param_.num_feature = mparam_.num_feature;
// Remove these once binary IO is gone. // Remove these once binary IO is gone.
cfg_["num_feature"] = common::ToString(mparam_.num_feature); cfg_["num_feature"] = common::ToString(mparam_.num_feature);
cfg_["num_class"] = common::ToString(mparam_.num_class); cfg_["num_class"] = common::ToString(mparam_.num_class);

View File

@@ -3,6 +3,7 @@ ARG CMAKE_VERSION=3.12
# Environment # Environment
ENV DEBIAN_FRONTEND noninteractive ENV DEBIAN_FRONTEND noninteractive
SHELL ["/bin/bash", "-c"] # Use Bash as shell
# Install all basic requirements # Install all basic requirements
RUN \ RUN \
@@ -19,10 +20,17 @@ ENV PATH=/opt/python/bin:$PATH
ENV GOSU_VERSION 1.10 ENV GOSU_VERSION 1.10
# Install Python packages # Create new Conda environment with Python 3.5
RUN conda create -n py35 python=3.5 && \
source activate py35 && \
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
source deactivate
# Install Python packages in default env
RUN \ RUN \
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \ pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh \
breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 jsonschema && \ recommonmark guzzle_sphinx_theme mock breathe graphviz \
pytest scikit-learn wheel kubernetes urllib3 jsonschema boto3 && \
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \ pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
pip install "dask[complete]" pip install "dask[complete]"

View File

@@ -5,31 +5,35 @@ set -x
suite=$1 suite=$1
# Install XGBoost Python package # Install XGBoost Python package
wheel_found=0 function install_xgboost {
for file in python-package/dist/*.whl wheel_found=0
do for file in python-package/dist/*.whl
do
if [ -e "${file}" ] if [ -e "${file}" ]
then then
pip install --user "${file}" pip install --user "${file}"
wheel_found=1 wheel_found=1
break # need just one break # need just one
fi fi
done done
if [ "$wheel_found" -eq 0 ] if [ "$wheel_found" -eq 0 ]
then then
pushd . pushd .
cd python-package cd python-package
python setup.py install --user python setup.py install --user
popd popd
fi fi
}
# Run specified test suite # Run specified test suite
case "$suite" in case "$suite" in
gpu) gpu)
install_xgboost
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu
;; ;;
mgpu) mgpu)
install_xgboost
pytest -v -s --fulltrace -m "mgpu" tests/python-gpu pytest -v -s --fulltrace -m "mgpu" tests/python-gpu
cd tests/distributed cd tests/distributed
./runtests-gpu.sh ./runtests-gpu.sh
@@ -39,17 +43,25 @@ case "$suite" in
cudf) cudf)
source activate cudf_test source activate cudf_test
install_xgboost
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
;; ;;
cpu) cpu)
install_xgboost
pytest -v -s --fulltrace tests/python pytest -v -s --fulltrace tests/python
cd tests/distributed cd tests/distributed
./runtests.sh ./runtests.sh
;; ;;
cpu-py35)
source activate py35
install_xgboost
pytest -v -s --fulltrace tests/python
;;
*) *)
echo "Usage: $0 {gpu|mgpu|cudf|cpu}" echo "Usage: $0 {gpu|mgpu|cudf|cpu|cpu-py35}"
exit 1 exit 1
;; ;;
esac esac

View File

@@ -54,7 +54,7 @@ TEST(Version, Basic) {
ptr = 0; ptr = 0;
v = std::stoi(str, &ptr); v = std::stoi(str, &ptr);
ASSERT_EQ(v, XGBOOST_VER_MINOR) << "patch: " << v;; ASSERT_EQ(v, XGBOOST_VER_PATCH) << "patch: " << v;;
str = str.substr(ptr); str = str.substr(ptr);
ASSERT_EQ(str.size(), 0); ASSERT_EQ(str.size(), 0);

View File

@@ -180,6 +180,41 @@ TEST(Learner, JsonModelIO) {
delete pp_dmat; delete pp_dmat;
} }
// Round-trip a trained Learner through the binary (non-JSON) serialisation
// format and verify the training configuration survives the reload.
TEST(Learner, BinaryModelIO) {
  size_t constexpr kRows = 8;
  int32_t constexpr kIters = 4;
  auto pp_dmat = CreateDMatrix(kRows, 10, 0);
  std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
  p_dmat->Info().labels_.Resize(kRows);
  std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
  // A non-default metric is set so we can later check it was persisted.
  learner->SetParam("eval_metric", "rmsle");
  learner->Configure();
  for (int32_t iter = 0; iter < kIters; ++iter) {
    learner->UpdateOneIter(iter, p_dmat.get());
  }
  dmlc::TemporaryDirectory tempdir;
  // NOTE(review): no '/' between tempdir.path and the file name — the file is
  // created adjacent to the temp directory, not inside it; confirm intended.
  std::string const fname = tempdir.path + "binary_model_io.bin";
  {
    // Make sure the write is complete before loading.
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
    learner->SaveModel(fo.get());
  }
  // Load into a fresh Learner so nothing leaks from the trained instance.
  learner.reset(Learner::Create({p_dmat}));
  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
  learner->LoadModel(fi.get());
  learner->Configure();
  Json config { Object() };
  learner->SaveConfig(&config);
  std::string config_str;
  Json::Dump(config, &config_str);
  // The metric set before saving must appear in the reloaded configuration.
  ASSERT_NE(config_str.find("rmsle"), std::string::npos);
  // The dumped config should not contain a "WARNING" marker after reload.
  ASSERT_EQ(config_str.find("WARNING"), std::string::npos);
  delete pp_dmat;
}
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
// Tests for automatic GPU configuration. // Tests for automatic GPU configuration.
TEST(Learner, GPUConfiguration) { TEST(Learner, GPUConfiguration) {

View File

@@ -0,0 +1,148 @@
import xgboost
import numpy as np
import os
# Shared dimensions for every generated model.
kRounds = 2
kRows = 1000
kCols = 4
kForests = 2
kMaxDepth = 2
kClasses = 3

# Seed *before* drawing any random data so the generated training data — and
# therefore the saved models — are reproducible across runs.  (Previously the
# seed was set after X and w were drawn, leaving them unseeded.)
np.random.seed(1994)

# Training data shared by all model generators below.
X = np.random.randn(kRows, kCols)
w = np.random.uniform(size=kRows)

# Version string embedded in every output file name.
version = xgboost.__version__

# Output directory for all generated model files.
target_dir = 'models'
def booster_bin(model):
    """Path of the binary file for a natively-saved Booster model."""
    filename = 'xgboost-' + version + '.' + model + '.bin'
    return os.path.join(target_dir, filename)
def booster_json(model):
    """Path of the JSON file for a natively-saved Booster model."""
    filename = 'xgboost-' + version + '.' + model + '.json'
    return os.path.join(target_dir, filename)
def skl_bin(model):
    """Path of the binary file for a scikit-learn wrapper model."""
    filename = 'xgboost_scikit-' + version + '.' + model + '.bin'
    return os.path.join(target_dir, filename)
def skl_json(model):
    """Path of the JSON file for a scikit-learn wrapper model."""
    filename = 'xgboost_scikit-' + version + '.' + model + '.json'
    return os.path.join(target_dir, filename)
def generate_regression_model():
    """Train and save a squared-error regression model through both the
    native Booster API and the scikit-learn wrapper."""
    print('Regression')
    y = np.random.randn(kRows)
    data = xgboost.DMatrix(X, label=y, weight=w)

    params = {'tree_method': 'hist',
              'num_parallel_tree': kForests,
              'max_depth': kMaxDepth}
    booster = xgboost.train(params, dtrain=data, num_boost_round=kRounds)
    booster.save_model(booster_bin('reg'))
    booster.save_model(booster_json('reg'))

    reg = xgboost.XGBRegressor(tree_method='hist',
                               num_parallel_tree=kForests,
                               max_depth=kMaxDepth,
                               n_estimators=kRounds)
    reg.fit(X, y, sample_weight=w)
    reg.save_model(skl_bin('reg'))
    reg.save_model(skl_json('reg'))
def generate_logistic_model():
    """Train and save a binary:logistic model through both the native
    Booster API and the scikit-learn wrapper."""
    print('Logistic')
    y = np.random.randint(0, 2, size=kRows)
    # Sanity check: both classes must actually occur.
    assert y.max() == 1 and y.min() == 0
    data = xgboost.DMatrix(X, label=y, weight=w)

    params = {'tree_method': 'hist',
              'num_parallel_tree': kForests,
              'max_depth': kMaxDepth,
              'objective': 'binary:logistic'}
    booster = xgboost.train(params, dtrain=data, num_boost_round=kRounds)
    booster.save_model(booster_bin('logit'))
    booster.save_model(booster_json('logit'))

    clf = xgboost.XGBClassifier(tree_method='hist',
                                num_parallel_tree=kForests,
                                max_depth=kMaxDepth,
                                n_estimators=kRounds)
    clf.fit(X, y, sample_weight=w)
    clf.save_model(skl_bin('logit'))
    clf.save_model(skl_json('logit'))
def generate_classification_model():
    """Train and save a multi-class model through both the native Booster
    API and the scikit-learn wrapper."""
    print('Classification')
    y = np.random.randint(0, kClasses, size=kRows)
    data = xgboost.DMatrix(X, label=y, weight=w)

    params = {'num_class': kClasses,
              'tree_method': 'hist',
              'num_parallel_tree': kForests,
              'max_depth': kMaxDepth}
    booster = xgboost.train(params, dtrain=data, num_boost_round=kRounds)
    booster.save_model(booster_bin('cls'))
    booster.save_model(booster_json('cls'))

    cls = xgboost.XGBClassifier(tree_method='hist',
                                num_parallel_tree=kForests,
                                max_depth=kMaxDepth,
                                n_estimators=kRounds)
    cls.fit(X, y, sample_weight=w)
    cls.save_model(skl_bin('cls'))
    cls.save_model(skl_json('cls'))
def generate_ranking_model():
    """Train and save a rank:ndcg model through both the native Booster
    API and the scikit-learn wrapper."""
    print('Learning to Rank')
    y = np.random.randint(5, size=kRows)
    # 20 query groups of 50 rows each, with one weight per group.  This local
    # `w` shadows the module-level per-row weight vector — presumably
    # intentional, since ranking weights are per-group here.
    w = np.random.uniform(size=20)
    g = np.repeat(50, 20)
    data = xgboost.DMatrix(X, y, weight=w)
    data.set_group(g)

    params = {'objective': 'rank:ndcg',
              'num_parallel_tree': kForests,
              'tree_method': 'hist',
              'max_depth': kMaxDepth}
    booster = xgboost.train(params, dtrain=data, num_boost_round=kRounds)
    booster.save_model(booster_bin('ltr'))
    booster.save_model(booster_json('ltr'))

    ranker = xgboost.sklearn.XGBRanker(n_estimators=kRounds,
                                       tree_method='hist',
                                       objective='rank:ndcg',
                                       max_depth=kMaxDepth,
                                       num_parallel_tree=kForests)
    ranker.fit(X, y, g, sample_weight=w)
    ranker.save_model(skl_bin('ltr'))
    ranker.save_model(skl_json('ltr'))
def write_versions():
    """Record the library versions used to generate the model files, so a
    future reader of `models/` knows what produced them."""
    info = str({'numpy': np.__version__,
                'xgboost': version})
    out_path = os.path.join(target_dir, 'version')
    with open(out_path, 'w') as fd:
        fd.write(info)
if __name__ == '__main__':
    # Create the output directory once, then emit every model flavour and
    # finish by recording the generating library versions.
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)
    for generate in (generate_regression_model,
                     generate_logistic_model,
                     generate_classification_model,
                     generate_ranking_model):
        generate()
    write_versions()

View File

@@ -39,7 +39,7 @@ class TestBasic(unittest.TestCase):
def test_basic(self): def test_basic(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, param = {'max_depth': 2, 'eta': 1,
'objective': 'binary:logistic'} 'objective': 'binary:logistic'}
# specify validations set to watch performance # specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')] watchlist = [(dtest, 'eval'), (dtrain, 'train')]

View File

@@ -5,6 +5,7 @@ import os
import json import json
import testing as tm import testing as tm
import pytest import pytest
import locale
dpath = 'demo/data/' dpath = 'demo/data/'
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
@@ -284,25 +285,42 @@ class TestModels(unittest.TestCase):
self.assertRaises(ValueError, bst.predict, dm1) self.assertRaises(ValueError, bst.predict, dm1)
bst.predict(dm2) # success bst.predict(dm2) # success
def test_model_binary_io(self):
model_path = 'test_model_binary_io.bin'
parameters = {'tree_method': 'hist', 'booster': 'gbtree',
'scale_pos_weight': '0.5'}
X = np.random.random((10, 3))
y = np.random.random((10,))
dtrain = xgb.DMatrix(X, y)
bst = xgb.train(parameters, dtrain, num_boost_round=2)
bst.save_model(model_path)
bst = xgb.Booster(model_file=model_path)
os.remove(model_path)
config = json.loads(bst.save_config())
assert float(config['learner']['objective'][
'reg_loss_param']['scale_pos_weight']) == 0.5
def test_model_json_io(self): def test_model_json_io(self):
model_path = './model.json' loc = locale.getpreferredencoding(False)
model_path = 'test_model_json_io.json'
parameters = {'tree_method': 'hist', 'booster': 'gbtree'} parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters) j_model = json_model(model_path, parameters)
assert isinstance(j_model['learner'], dict) assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json') bst = xgb.Booster(model_file=model_path)
bst.save_model(fname=model_path) bst.save_model(fname=model_path)
with open('./model.json', 'r') as fd: with open(model_path, 'r') as fd:
j_model = json.load(fd) j_model = json.load(fd)
assert isinstance(j_model['learner'], dict) assert isinstance(j_model['learner'], dict)
os.remove(model_path) os.remove(model_path)
assert locale.getpreferredencoding(False) == loc
@pytest.mark.skipif(**tm.no_json_schema()) @pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self): def test_json_schema(self):
import jsonschema import jsonschema
model_path = './model.json' model_path = 'test_json_schema.json'
path = os.path.dirname( path = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
doc = os.path.join(path, 'doc', 'model.schema') doc = os.path.join(path, 'doc', 'model.schema')

View File

@@ -0,0 +1,130 @@
import xgboost
import os
import generate_models as gm
import json
import zipfile
import pytest
def run_model_param_check(config):
    """Validate the learner parameters shared by every generated model:
    4 input features and the gbtree booster."""
    learner = config['learner']
    assert learner['learner_model_param']['num_feature'] == '4'
    assert learner['learner_train_param']['booster'] == 'gbtree'
def run_booster_check(booster, name):
    """Check a natively-saved Booster against the constants it was
    generated with (see generate_models); `name` encodes the task."""
    config = json.loads(booster.save_config())
    run_model_param_check(config)
    model_param = config['learner']['learner_model_param']
    train_param = config['learner']['learner_train_param']
    num_trees = len(booster.get_dump())
    if 'cls' in name:
        # Multi-class grows one forest per class per round.
        assert num_trees == gm.kForests * gm.kRounds * gm.kClasses
        assert float(model_param['base_score']) == 0.5
        assert train_param['objective'] == 'multi:softmax'
    elif 'logit' in name:
        assert num_trees == gm.kForests * gm.kRounds
        # Binary classification reports num_class as 0.
        assert model_param['num_class'] == '0'
        assert train_param['objective'] == 'binary:logistic'
    elif 'ltr' in name:
        assert train_param['objective'] == 'rank:ndcg'
    else:
        assert 'reg' in name
        assert num_trees == gm.kForests * gm.kRounds
        assert float(model_param['base_score']) == 0.5
        assert train_param['objective'] == 'reg:squarederror'
def run_scikit_model_check(name, path):
    """Load a scikit-learn wrapper model from `path` and validate it.

    `name` encodes the task ('reg', 'cls', 'logit', 'ltr') and possibly the
    xgboost version the file was generated with (e.g. '0.90').
    """
    if 'reg' in name:
        reg = xgboost.XGBRegressor()
        reg.load_model(path)
        config = json.loads(reg.get_booster().save_config())
        # 0.90-era models predate the reg:linear -> reg:squarederror rename.
        if '0.90' in name:
            expected = 'reg:linear'
        else:
            expected = 'reg:squarederror'
        assert config['learner']['learner_train_param']['objective'] == expected
        assert (len(reg.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        run_model_param_check(config)
    elif 'cls' in name:
        cls = xgboost.XGBClassifier()
        cls.load_model(path)
        # Class metadata attributes are only checked for models generated by
        # versions newer than 0.90.
        if '0.90' not in name:
            assert len(cls.classes_) == gm.kClasses
            assert len(cls._le.classes_) == gm.kClasses
            assert cls.n_classes_ == gm.kClasses
        assert (len(cls.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests * gm.kClasses), path
        config = json.loads(cls.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'multi:softprob', path
        run_model_param_check(config)
    elif 'ltr' in name:
        ltr = xgboost.XGBRanker()
        ltr.load_model(path)
        assert (len(ltr.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(ltr.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'rank:ndcg'
        run_model_param_check(config)
    elif 'logit' in name:
        logit = xgboost.XGBClassifier()
        logit.load_model(path)
        assert (len(logit.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(logit.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'binary:logistic'
    else:
        assert False
@pytest.mark.ci
def test_model_compatibility():
    '''Test model compatibility, can only be run on CI as others don't
    have the credentials.

    Downloads a zip of models generated by older xgboost versions from S3,
    extracts it next to this file, and checks every model still loads via
    either the native Booster or the scikit-learn wrappers.
    '''
    path = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(path, 'models')

    # boto3 is an optional dependency; skip rather than fail when absent.
    try:
        import boto3
        import botocore
    except ImportError:
        pytest.skip(
            'Skipping compatibility tests as boto3 is not installed.')

    # Only CI has S3 credentials; skip everywhere else.
    try:
        s3_bucket = boto3.resource('s3').Bucket('xgboost-ci-jenkins-artifacts')
        zip_path = 'xgboost_model_compatibility_test.zip'
        s3_bucket.download_file(zip_path, zip_path)
    except botocore.exceptions.NoCredentialsError:
        pytest.skip(
            'Skipping compatibility tests as running on non-CI environment.')

    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(path)
    # Collect every extracted model file; 'version' is metadata, not a model.
    models = [
        os.path.join(root, f) for root, subdir, files in os.walk(path)
        for f in files
        if f != 'version'
    ]
    assert models

    # Loop variable renamed from `path` so it no longer shadows the models
    # directory computed above.
    for model_path in models:
        name = os.path.basename(model_path)
        # The file-name prefix encodes which API produced the model.
        if name.startswith('xgboost-'):
            booster = xgboost.Booster(model_file=model_path)
            run_booster_check(booster, name)
        elif name.startswith('xgboost_scikit'):
            run_scikit_model_check(name, model_path)
        else:
            assert False

View File

@@ -115,7 +115,6 @@ class TestRanking(unittest.TestCase):
# model training parameters # model training parameters
cls.params = {'objective': 'rank:pairwise', cls.params = {'objective': 'rank:pairwise',
'booster': 'gbtree', 'booster': 'gbtree',
'silent': 0,
'eval_metric': ['ndcg'] 'eval_metric': ['ndcg']
} }
@@ -153,7 +152,8 @@ class TestRanking(unittest.TestCase):
Test cross-validation with a group specified Test cross-validation with a group specified
""" """
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500, cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False) early_stopping_rounds=10, shuffle=False, nfold=10,
as_pandas=False)
assert isinstance(cv, dict) assert isinstance(cv, dict)
assert len(cv) == 4 assert len(cv) == 4

View File

@@ -10,18 +10,20 @@ if sys.platform.startswith("win"):
pytestmark = pytest.mark.skipif(**tm.no_dask()) pytestmark = pytest.mark.skipif(**tm.no_dask())
try: try:
from distributed.utils_test import client, loop, cluster_fixture from distributed import LocalCluster, Client
import dask.dataframe as dd import dask.dataframe as dd
import dask.array as da import dask.array as da
from xgboost.dask import DaskDMatrix from xgboost.dask import DaskDMatrix
except ImportError: except ImportError:
client = None LocalCluster = None
loop = None Client = None
cluster_fixture = None dd = None
pass da = None
DaskDMatrix = None
kRows = 1000 kRows = 1000
kCols = 10 kCols = 10
kWorkers = 5
def generate_array(): def generate_array():
@@ -31,7 +33,9 @@ def generate_array():
return X, y return X, y
def test_from_dask_dataframe(client): def test_from_dask_dataframe():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
X = dd.from_dask_array(X) X = dd.from_dask_array(X)
@@ -51,11 +55,13 @@ def test_from_dask_dataframe(client):
# evals_result is not supported in dask interface. # evals_result is not supported in dask interface.
xgb.dask.train( xgb.dask.train(
client, {}, dtrain, num_boost_round=2, evals_result={}) client, {}, dtrain, num_boost_round=2, evals_result={})
# force prediction to be computed
prediction = prediction.compute() # force prediction to be computed prediction = prediction.compute()
def test_from_dask_array(client): def test_from_dask_array():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
dtrain = DaskDMatrix(client, X, y) dtrain = DaskDMatrix(client, X, y)
# results is {'booster': Booster, 'history': {...}} # results is {'booster': Booster, 'history': {...}}
@@ -65,11 +71,13 @@ def test_from_dask_array(client):
assert prediction.shape[0] == kRows assert prediction.shape[0] == kRows
assert isinstance(prediction, da.Array) assert isinstance(prediction, da.Array)
# force prediction to be computed
prediction = prediction.compute() # force prediction to be computed prediction = prediction.compute()
def test_regressor(client): def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor.set_params(tree_method='hist') regressor.set_params(tree_method='hist')
@@ -89,10 +97,13 @@ def test_regressor(client):
assert len(history['validation_0']['rmse']) == 2 assert len(history['validation_0']['rmse']) == 2
def test_classifier(client): def test_dask_classifier():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
X, y = generate_array() X, y = generate_array()
y = (y * 10).astype(np.int32) y = (y * 10).astype(np.int32)
classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2) classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
classifier.client = client classifier.client = client
classifier.fit(X, y, eval_set=[(X, y)]) classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X) prediction = classifier.predict(X)
@@ -164,11 +175,15 @@ def run_empty_dmatrix(client, parameters):
# No test for Exact, as empty DMatrix handling are mostly for distributed # No test for Exact, as empty DMatrix handling are mostly for distributed
# environment and Exact doesn't support it. # environment and Exact doesn't support it.
def test_empty_dmatrix_hist(client): def test_empty_dmatrix_hist():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'hist'} parameters = {'tree_method': 'hist'}
run_empty_dmatrix(client, parameters) run_empty_dmatrix(client, parameters)
def test_empty_dmatrix_approx(client): def test_empty_dmatrix_approx():
with LocalCluster(n_workers=5) as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'approx'} parameters = {'tree_method': 'approx'}
run_empty_dmatrix(client, parameters) run_empty_dmatrix(client, parameters)

View File

@@ -490,6 +490,13 @@ def test_kwargs():
assert clf.get_params()['n_estimators'] == 1000 assert clf.get_params()['n_estimators'] == 1000
def test_kwargs_error():
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
with pytest.raises(TypeError):
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)
def test_kwargs_grid_search(): def test_kwargs_grid_search():
from sklearn.model_selection import GridSearchCV from sklearn.model_selection import GridSearchCV
from sklearn import datasets from sklearn import datasets
@@ -510,13 +517,6 @@ def test_kwargs_grid_search():
assert len(means) == len(set(means)) assert len(means) == len(set(means))
def test_kwargs_error():
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
with pytest.raises(TypeError):
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)
def test_sklearn_clone(): def test_sklearn_clone():
from sklearn.base import clone from sklearn.base import clone
@@ -525,6 +525,17 @@ def test_sklearn_clone():
clone(clf) clone(clf)
def test_sklearn_get_default_params():
from sklearn.datasets import load_digits
digits_2class = load_digits(2)
X = digits_2class['data']
y = digits_2class['target']
cls = xgb.XGBClassifier()
assert cls.get_params()['base_score'] is None
cls.fit(X[:4, ...], y[:4, ...])
assert cls.get_params()['base_score'] is not None
def test_validation_weights_xgbmodel(): def test_validation_weights_xgbmodel():
from sklearn.datasets import make_hastie_10_2 from sklearn.datasets import make_hastie_10_2