diff --git a/Makefile b/Makefile
index e3f3134e4..abe8ccfaa 100644
--- a/Makefile
+++ b/Makefile
@@ -73,7 +73,7 @@ endif
# specify tensor path
-.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
+.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
@@ -131,8 +131,11 @@ rcpplint:
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src
lint: rcpplint
- python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin
+ python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package
+pylint:
+ flake8 --ignore E501 python-package
+ flake8 --ignore E501 tests/python
clean:
$(RM) -rf build build_plugin lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost
diff --git a/NEWS.md b/NEWS.md
index 81afdbb5a..35cfc74f5 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -31,7 +31,9 @@ This file records the changes in xgboost library in reverse chronological order.
* JVM Package
- Enable xgboost4j for java and scala
- XGBoost distributed now runs on Flink and Spark.
-
+* Support listing of model attributes for meta data.
+ - https://github.com/dmlc/xgboost/pull/1198
+ - https://github.com/dmlc/xgboost/pull/1166
## v0.47 (2016.01.14)
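
The attribute APIs behind this entry surface in the Python package as `Booster.set_attr`/`Booster.attr` (used by the new early-stopping callback later in this diff). A minimal sketch of the listing use case; the `attributes()` listing call and the demo data path are assumptions based on the PR titles:

```python
import xgboost as xgb

dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')  # assumed demo path
bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=2)

# store free-form meta data on the model
bst.set_attr(data_version='v1', owner='team-a')
print(bst.attr('owner'))   # -> 'team-a'
print(bst.attributes())    # assumed listing API: dict of all stored attributes
```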
diff --git a/README.md b/README.md
index edcb68043..8b9c10168 100644
--- a/README.md
+++ b/README.md
@@ -43,12 +43,7 @@ License
-------
© Contributors, 2016. Licensed under an [Apache-2](https://github.com/dmlc/xgboost/blob/master/LICENSE) license.
-
Reference
---------
-- Tianqi Chen and Carlos Guestrin. [XGBoost: A Scalable Tree Boosting System](http://arxiv.org/abs/1603.02754). Arxiv.1603.02754
+- Tianqi Chen and Carlos Guestrin. [XGBoost: A Scalable Tree Boosting System](http://arxiv.org/abs/1603.02754). In 22nd SIGKDD Conference on Knowledge Discovery and Data Mining, 2016
- XGBoost originates from research project at University of Washington, see also the [Project Page at UW](http://dmlc.cs.washington.edu/xgboost.html).
-
-Acknowledgements
-----------------
-- This work was supported in part by ONR (PECASE) N000141010672, NSF IIS 1258741 and the TerraSwarm Research Center sponsored by MARCO and DARPA.
diff --git a/demo/README.md b/demo/README.md
index e4117cd8b..fb5505fbc 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -66,6 +66,7 @@ However, the parameter settings can be applied to all versions
- [Starter script for Kaggle Higgs Boson](kaggle-higgs)
- [Kaggle Tradeshift winning solution by daxiongshu](https://github.com/daxiongshu/kaggle-tradeshift-winning-solution)
+- [Benchmarking the most commonly used open source tools for binary classification](https://github.com/szilard/benchm-ml#boosting-gradient-boosted-treesgradient-boosting-machines)
## Machine Learning Challenge Winning Solutions
@@ -85,6 +86,10 @@ Please send pull requests if you find ones that are missing here.
- Owen Zhang, 1st place of the [Avito Context Ad Clicks competition](https://www.kaggle.com/c/avito-context-ad-clicks). Link to [the Kaggle interview](http://blog.kaggle.com/2015/08/26/avito-winners-interview-1st-place-owen-zhang/).
- Keiichi Kuroyanagi, 2nd place of the [Airbnb New User Bookings](https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings). Link to [the Kaggle interview](http://blog.kaggle.com/2016/03/17/airbnb-new-user-bookings-winners-interview-2nd-place-keiichi-kuroyanagi-keiku/).
- Marios Michailidis, Mathias Müller and Ning Situ, 1st place [Homesite Quote Conversion](https://www.kaggle.com/c/homesite-quote-conversion). Link to [the Kaggle interview](http://blog.kaggle.com/2016/04/08/homesite-quote-conversion-winners-write-up-1st-place-kazanova-faron-clobber/).
+
+## Talks
+- [XGBoost: A Scalable Tree Boosting System](http://datascience.la/xgboost-workshop-and-meetup-talk-with-tianqi-chen/) (video+slides) by Tianqi Chen at the Los Angeles Data Science meetup
+
## Tutorials
- [XGBoost Official RMarkdown Tutorials](https://xgboost.readthedocs.org/en/latest/R-package/index.html#tutorials)
diff --git a/demo/guide-python/cross_validation.py b/demo/guide-python/cross_validation.py
index 6ca13d460..5c8ee0b1b 100755
--- a/demo/guide-python/cross_validation.py
+++ b/demo/guide-python/cross_validation.py
@@ -12,15 +12,18 @@ print ('running cross validation')
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
- metrics={'error'}, seed = 0)
+ metrics={'error'}, seed = 0,
+ callbacks=[xgb.callback.print_evaluation(show_stdv=True)])
print ('running cross validation, disable standard deviation display')
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
- metrics={'error'}, seed = 0, show_stdv = False)
-
+res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
+ metrics={'error'}, seed = 0,
+ callbacks=[xgb.callback.print_evaluation(show_stdv=False),
+ xgb.callback.early_stop(3)])
+print (res)
print ('running cross validation, with preprocessing function')
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
@@ -58,4 +61,3 @@ param = {'max_depth':2, 'eta':1, 'silent':1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
obj = logregobj, feval=evalerror)
-
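
The preprocessing-function variant referenced above hands each fold's training set, test set and parameters to a user hook before training. A minimal sketch of such a hook, assuming the Higgs-style per-fold rebalancing of `scale_pos_weight`:

```python
import numpy as np

def fpreproc(dtrain, dtest, param):
    """Rebalance positive/negative weights inside each CV fold (illustrative hook)."""
    label = dtrain.get_label()
    ratio = float(np.sum(label == 0)) / np.sum(label == 1)
    param['scale_pos_weight'] = ratio
    return (dtrain, dtest, param)

# used as: xgb.cv(param, dtrain, num_round, nfold=5, metrics={'auc'}, seed=0, fpreproc=fpreproc)
```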
diff --git a/dmlc-core b/dmlc-core
index 1db0792e1..9fd3b4846 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 1db0792e1a55355b1f07699bba18c88ded996953
+Subproject commit 9fd3b48462a7a651e12a197679f71e043dcb25a2
diff --git a/doc/R-package/index.md b/doc/R-package/index.md
index 92df95e9f..6333d43fb 100644
--- a/doc/R-package/index.md
+++ b/doc/R-package/index.md
You have found the XGBoost R Package!
Get Started
-----------
* Checkout the [Installation Guide](../build.md) contains instructions to install xgboost, and [Tutorials](#tutorials) for examples on how to use xgboost for various tasks.
-* Please visit [walk through example](demo).
+* Please visit [walk through example](../../R-package/demo).
Tutorials
---------
diff --git a/doc/_static/cn.svg b/doc/_static/cn.svg
new file mode 100644
index 000000000..515176d60
--- /dev/null
+++ b/doc/_static/cn.svg
@@ -0,0 +1,20 @@
+<!-- SVG markup stripped in extraction; title: Flag of the People's Republic of China -->
diff --git a/doc/_static/js/auto_module_index.js b/doc/_static/js/auto_module_index.js
new file mode 100644
index 000000000..b918ecdc1
--- /dev/null
+++ b/doc/_static/js/auto_module_index.js
@@ -0,0 +1,25 @@
+function auto_index(module) {
+ $(document).ready(function () {
+ // find all classes or functions
+ var div_query = "div[class='section'][id='module-" + module + "']";
+ var class_query = div_query + " dl[class='class'] > dt";
+ var func_query = div_query + " dl[class='function'] > dt";
+ var targets = $(class_query + ',' + func_query);
+
+ var li_node = $("li a[href='#module-" + module + "']").parent();
+    var html = "<ul>";
+
+ for (var i = 0; i < targets.length; ++i) {
+ var id = $(targets[i]).attr('id');
+ // remove 'mxnet.' prefix to make menus shorter
+ var id_simple = id.replace(/^mxnet\./, '');
+      html += "<li><a href='#" + id + "'>" + id_simple + "</a></li>";
+ }
+
+    html += "</ul>";
+ li_node.append(html);
+ });
+}
+
diff --git a/doc/_static/us.svg b/doc/_static/us.svg
new file mode 100644
index 000000000..1d621f96d
--- /dev/null
+++ b/doc/_static/us.svg
@@ -0,0 +1,117 @@
+<!-- SVG markup stripped in extraction; title: The United States of America flag, produced by Daniel McRae -->
diff --git a/doc/_static/xgboost-theme/footer.html b/doc/_static/xgboost-theme/footer.html
new file mode 100644
index 000000000..148bcb7df
--- /dev/null
+++ b/doc/_static/xgboost-theme/footer.html
@@ -0,0 +1,5 @@
+<!-- footer markup stripped in extraction; no recoverable text content -->
diff --git a/doc/_static/xgboost-theme/index.html b/doc/_static/xgboost-theme/index.html
new file mode 100644
index 000000000..852e00a1a
--- /dev/null
+++ b/doc/_static/xgboost-theme/index.html
@@ -0,0 +1,58 @@
+<!-- landing page markup stripped in extraction; recovered text content: -->
+
+Scalable and Flexible Gradient Boosting
+
+Flexible
+Supports regression, classification, ranking and user defined objectives.
+
+Portable
+Runs on Windows, Linux and OS X, as well as various cloud platforms.
+
+Multiple Languages
+Supports multiple languages including C++, Python, R, Java, Scala, Julia.
+
+Battle-tested
+Wins many data science and machine learning challenges.
+Used in production by multiple companies.
+
+Distributed on Cloud
+Supports distributed training on multiple machines, including AWS,
+GCE, Azure, and Yarn clusters. Can be integrated with Flink, Spark and other cloud dataflow systems.
+
+Performance
+The well-optimized backend system for the best performance with limited resources.
+The distributed version solves problems beyond billions of examples with the same code.
diff --git a/doc/_static/xgboost-theme/layout.html b/doc/_static/xgboost-theme/layout.html
new file mode 100644
index 000000000..2931fe594
--- /dev/null
+++ b/doc/_static/xgboost-theme/layout.html
@@ -0,0 +1,156 @@
+{%- block doctype -%}
+{# doctype declaration omitted in extraction #}
+{%- endblock %}
+{%- set reldelim1 = reldelim1 is not defined and ' »' or reldelim1 %}
+{%- set reldelim2 = reldelim2 is not defined and ' |' or reldelim2 %}
+{%- set render_sidebar = (not embedded) and (not theme_nosidebar|tobool) and
+                         (sidebars != []) %}
+{%- set url_root = pathto('', 1) %}
+{%- if url_root == '#' %}{% set url_root = '' %}{% endif %}
+{%- if not embedded and docstitle %}
+  {%- set titlesuffix = " — "|safe + docstitle|e %}
+{%- else %}
+  {%- set titlesuffix = "" %}
+{%- endif %}
+
+{%- macro searchform(classes, button) %}
+  {# search form markup omitted #}
+{%- endmacro %}
+
+{%- macro sidebarglobal() %}
+  {{ toctree(maxdepth=2|toint, collapse=False, includehidden=theme_globaltoc_includehidden|tobool) }}
+{%- endmacro %}
+
+{%- macro sidebar() %}
+  {%- if render_sidebar %}
+    {# sidebar container markup omitted #}
+  {%- endif %}
+{%- endmacro %}
+
+{%- macro script() %}
+  {# documentation options script omitted #}
+  {% for name in ['jquery.js', 'underscore.js', 'doctools.js', 'searchtools.js'] %}
+    {# script tag for each name omitted #}
+  {% endfor %}
+  {# additional theme script tags omitted #}
+{%- endmacro %}
+
+{%- macro css() %}
+  {# base stylesheet links omitted #}
+  {% if pagename == 'index' %}
+    {# landing page stylesheet link omitted #}
+  {%- else %}
+    {# documentation stylesheet links omitted #}
+  {%- endif %}
+{%- endmacro %}
+
+{# opening html/head markup and meta tags omitted #}
+  {# The above 3 meta tags *must* come first in the head; any other head content
+     must come *after* these tags. #}
+  {{ metatags }}
+  {%- block htmltitle %}
+    {%- if pagename != 'index' %}
+      <title>{{ title|striptags|e }}{{ titlesuffix }}</title>
+    {%- else %}
+      <title>XGBoost Documents</title>
+    {%- endif %}
+  {%- endblock %}
+  {{ css() }}
+  {%- if not embedded %}
+    {{ script() }}
+    {%- if use_opensearch %}
+      {# opensearch link omitted #}
+    {%- endif %}
+    {%- if favicon %}
+      {# favicon link omitted #}
+    {%- endif %}
+  {%- endif %}
+{%- block linktags %}
+  {%- if hasdoc('about') %}{# "about" link tag omitted #}{%- endif %}
+  {%- if hasdoc('genindex') %}{# "index" link tag omitted #}{%- endif %}
+  {%- if hasdoc('search') %}{# "search" link tag omitted #}{%- endif %}
+  {%- if hasdoc('copyright') %}{# "copyright" link tag omitted #}{%- endif %}
+  {%- if parents %}{# "up" link tag omitted #}{%- endif %}
+  {%- if next %}{# "next" link tag omitted #}{%- endif %}
+  {%- if prev %}{# "prev" link tag omitted #}{%- endif %}
+{%- endblock %}
+{%- block extrahead %} {% endblock %}
+
+{# body opening markup omitted #}
+  {%- include "navbar.html" %}
+  {% if pagename != 'index' %}
+    {{ sidebar() }}
+    {% block body %} {% endblock %}
+    {%- include "footer.html" %}
+  {%- else %}
+    {%- include "index.html" %}
+    {%- include "footer.html" %}
+  {%- endif %}
+{# closing body/html markup omitted #}
diff --git a/doc/_static/xgboost-theme/navbar.html b/doc/_static/xgboost-theme/navbar.html
new file mode 100644
index 000000000..44c66cf9c
--- /dev/null
+++ b/doc/_static/xgboost-theme/navbar.html
@@ -0,0 +1,40 @@
+{# navbar markup stripped in extraction; surviving template logic and text: #}
+{# brand link: XGBoost #}
+{% for name in ['Get Started', 'Tutorials', 'How To'] %}
+  {# nav link for each name omitted #}
+{% endfor %}
+{% for name in ['Packages'] %}
+  {# dropdown menu for each name, with per-package links, omitted #}
+{% endfor %}
+{# nav link: Knobs #}
+{{searchform('', False)}}
diff --git a/doc/_static/xgboost-theme/theme.conf b/doc/_static/xgboost-theme/theme.conf
new file mode 100644
index 000000000..89e03bbda
--- /dev/null
+++ b/doc/_static/xgboost-theme/theme.conf
@@ -0,0 +1,2 @@
+[theme]
+inherit = basic
diff --git a/doc/_static/xgboost.css b/doc/_static/xgboost.css
new file mode 100644
index 000000000..f4862a706
--- /dev/null
+++ b/doc/_static/xgboost.css
@@ -0,0 +1,232 @@
+/* header section */
+.splash{
+ padding:5em 0 1em 0;
+ background-color:#0079b2;
+ /* background-image:url(../img/bg.jpg); */
+ background-size:cover;
+ background-attachment:fixed;
+ color:#fff;
+ text-align:center
+}
+
+.splash h1{
+ font-size: 40px;
+ margin-bottom: 20px;
+}
+.splash .social{
+ margin:2em 0
+}
+
+.splash .get_start {
+ margin:2em 0
+}
+
+.splash .get_start_btn {
+ border: 2px solid #FFFFFF;
+ border-radius: 5px;
+ color: #FFFFFF;
+ display: inline-block;
+ font-size: 26px;
+ padding: 9px 20px;
+}
+
+.section-tout{
+ padding:3em 0 3em;
+ border-bottom:1px solid rgba(0,0,0,.05);
+ background-color:#eaf1f1
+}
+.section-tout .fa{
+ margin-right:.5em
+}
+
+.section-tout h3{
+ font-size:20px;
+}
+
+.section-tout p {
+ margin-bottom:2em
+}
+
+.section-inst{
+ padding:3em 0 3em;
+ border-bottom:1px solid rgba(0,0,0,.05);
+
+ text-align:center
+}
+
+.section-inst p {
+ margin-bottom:2em
+}
+.section-inst img {
+ -webkit-filter: grayscale(90%); /* Chrome, Safari, Opera */
+ filter: grayscale(90%);
+ margin-bottom:2em
+}
+.section-inst img:hover {
+ -webkit-filter: grayscale(0%); /* Chrome, Safari, Opera */
+ filter: grayscale(0%);
+}
+
+.footer{
+ padding-top: 40px;
+}
+.footer li{
+ float:right;
+ margin-right:1.5em;
+ margin-bottom:1.5em
+}
+.footer p{
+ font-size: 15px;
+ color: #888;
+ clear:right;
+ margin-bottom:0
+}
+
+
+/* sidebar */
+div.sphinxsidebar {
+ margin-top: 20px;
+ margin-left: 0;
+ position: fixed;
+ overflow-y: scroll;
+ width: 250px;
+ top: 52px;
+ bottom: 0;
+ display: none
+}
+div.sphinxsidebar ul { padding: 0 }
+div.sphinxsidebar ul ul { margin-left: 15px }
+
+@media (min-width:1200px) {
+ .content { float: right; width: 66.66666667%; margin-right: 5% }
+ div.sphinxsidebar {display: block}
+}
+
+
+.github-btn { border: 0; overflow: hidden }
+
+.container {
+ margin-right: auto;
+ margin-left: auto;
+ padding-left: 15px;
+ padding-right: 15px
+}
+
+body>.container {
+ padding-top: 80px
+}
+
+body {
+ font-size: 16px;
+}
+
+pre {
+ font-size: 14px;
+}
+
+/* navbar */
+.navbar {
+ background-color:#0079b2;
+ border: 0px;
+ height: 65px;
+}
+.navbar-right li {
+ display:inline-block;
+ vertical-align:top;
+ padding: 22px 4px;
+}
+
+.navbar-left li {
+ display:inline-block;
+ vertical-align:top;
+ padding: 17px 10px;
+ /* margin: 0 5px; */
+}
+
+.navbar-left li a {
+ font-size: 22px;
+ color: #fff;
+}
+
+.navbar-left > li > a:hover{
+ color:#fff;
+}
+.flag-icon {
+ background-size: contain;
+ background-position: 50%;
+ background-repeat: no-repeat;
+ position: relative;
+ display: inline-block;
+ width: 1.33333333em;
+ line-height: 1em;
+}
+
+.flag-icon:before {
+ content: "\00a0";
+}
+
+
+.flag-icon-cn {
+ background-image: url(./cn.svg);
+}
+
+.flag-icon-us {
+ background-image: url(./us.svg);
+}
+
+
+/* .flags { */
+/* padding: 10px; */
+/* } */
+
+.navbar-brand >img {
+ width: 110px;
+}
+
+.dropdown-menu li {
+ padding: 0px 0px;
+ width: 120px;
+}
+.dropdown-menu li a {
+ color: #0079b2;
+ font-size: 20px;
+}
+
+.section h1 {
+ padding-top: 90px;
+ margin-top: -60px;
+ padding-bottom: 10px;
+ font-size: 28px;
+}
+
+.section h2 {
+ padding-top: 80px;
+ margin-top: -60px;
+ padding-bottom: 10px;
+ font-size: 22px;
+}
+
+.section h3 {
+ padding-top: 80px;
+ margin-top: -64px;
+ padding-bottom: 8px;
+}
+
+.section h4 {
+ padding-top: 80px;
+ margin-top: -64px;
+ padding-bottom: 8px;
+}
+
+dt {
+ margin-top: -76px;
+ padding-top: 76px;
+}
+
+dt:target, .highlighted {
+ background-color: #fff;
+}
+
+.section code.descname {
+ font-size: 1em;
+}
diff --git a/doc/cli/index.md b/doc/cli/index.md
new file mode 100644
index 000000000..23611d699
--- /dev/null
+++ b/doc/cli/index.md
@@ -0,0 +1,3 @@
+# XGBoost Command Line version
+
+See [XGBoost Command Line walkthrough](https://github.com/dmlc/xgboost/blob/master/demo/binary_classification/README.md)
diff --git a/doc/conf.py b/doc/conf.py
index b2f82e726..65ebddccd 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -117,10 +117,11 @@ todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
+html_theme_path = ['_static']
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# html_theme = 'alabaster'
-html_theme = 'sphinx_rtd_theme'
+html_theme = 'xgboost-theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
diff --git a/doc/get_started/index.md b/doc/get_started/index.md
new file mode 100644
index 000000000..50c960421
--- /dev/null
+++ b/doc/get_started/index.md
@@ -0,0 +1,80 @@
+# Get Started with XGBoost
+
+This is a quick start tutorial with snippets for you to quickly try out xgboost
+on the demo dataset for a binary classification task.
+
+## Links to Other Helpful Resources
+- See [Installation Guide](../build.md) on how to install xgboost.
+- See [How to pages](../how_to/index.md) for various tips on using xgboost.
+- See [Tutorials](../tutorials/index.md) for tutorials on specific tasks.
+- See [Learning to use XGBoost by Examples](../../demo) for more code examples.
+
+## Python
+```python
+import xgboost as xgb
+# read in data
+dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
+dtest = xgb.DMatrix('demo/data/agaricus.txt.test')
+# specify parameters via map
+param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
+num_round = 2
+bst = xgb.train(param, dtrain, num_round)
+# make prediction
+preds = bst.predict(dtest)
+```
+
+## R
+
+```r
+# load data
+data(agaricus.train, package='xgboost')
+data(agaricus.test, package='xgboost')
+train <- agaricus.train
+test <- agaricus.test
+# fit model
+bst <- xgboost(data = train$data, label = train$label, max.depth = 2, eta = 1, nround = 2,
+ nthread = 2, objective = "binary:logistic")
+# predict
+pred <- predict(bst, test$data)
+
+```
+
+## Julia
+```julia
+using XGBoost
+# read data
+train_X, train_Y = readlibsvm("demo/data/agaricus.txt.train", (6513, 126))
+test_X, test_Y = readlibsvm("demo/data/agaricus.txt.test", (1611, 126))
+# fit model
+num_round = 2
+bst = xgboost(train_X, num_round, label=train_Y, eta=1, max_depth=2)
+# predict
+pred = predict(bst, test_X)
+```
+
+## Scala
+```scala
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.XGBoost
+
+object XGBoostScalaExample {
+ def main(args: Array[String]) {
+    // read training data, available at xgboost/demo/data
+ val trainData =
+ new DMatrix("/path/to/agaricus.txt.train")
+ // define parameters
+ val paramMap = List(
+ "eta" -> 0.1,
+ "max_depth" -> 2,
+ "objective" -> "binary:logistic").toMap
+ // number of iterations
+ val round = 2
+ // train the model
+ val model = XGBoost.train(trainData, paramMap, round)
+ // run prediction
+ val predTrain = model.predict(trainData)
+ // save model to the file.
+ model.saveModel("/local/path/to/model")
+ }
+}
+```
diff --git a/doc/dev-guide/contribute.md b/doc/how_to/contribute.md
similarity index 100%
rename from doc/dev-guide/contribute.md
rename to doc/how_to/contribute.md
diff --git a/doc/external_memory.md b/doc/how_to/external_memory.md
similarity index 100%
rename from doc/external_memory.md
rename to doc/how_to/external_memory.md
diff --git a/doc/how_to/index.md b/doc/how_to/index.md
new file mode 100644
index 000000000..afc69f777
--- /dev/null
+++ b/doc/how_to/index.md
@@ -0,0 +1,16 @@
+# XGBoost How To
+
+This page contains guidelines on using and developing XGBoost.
+
+## Installation
+- [How to Install XGBoost](../build.md)
+
+## Use XGBoost in Specific Ways
+- [Parameter tuning guide](param_tuning.md)
+- [Use out of core computation for large datasets](external_memory.md)
+
+## Develop and Hack XGBoost
+- [Contribute to XGBoost](contribute.md)
+
+## Frequently Asked Questions
+- [FAQ](../faq.md)
diff --git a/doc/param_tuning.md b/doc/how_to/param_tuning.md
similarity index 100%
rename from doc/param_tuning.md
rename to doc/how_to/param_tuning.md
diff --git a/doc/index.md b/doc/index.md
index 9d95944e6..55d90d95b 100644
--- a/doc/index.md
+++ b/doc/index.md
@@ -1,59 +1,15 @@
XGBoost Documentation
=====================
-This is document of xgboost library.
-XGBoost is short for eXtreme gradient boosting. This is a library that is designed, and optimized for boosted (tree) algorithms.
-The goal of this library is to push the extreme of the computation limits of machines to provide a ***scalable***, ***portable*** and ***accurate***
-for large scale tree boosting.
-
This document is hosted at http://xgboost.readthedocs.org/. You can also browse most of the documents in github directly.
-Package Documents
------------------
-This section contains language specific package guide.
-* [XGBoost Command Line Usage Walkthrough](../demo/binary_classification/README.md)
+These are used to generate the index used in search.
+
* [Python Package Document](python/index.md)
* [R Package Document](R-package/index.md)
* [Java/Scala Package Document](jvm/index.md)
-* [XGBoost.jl Julia Package](https://github.com/dmlc/XGBoost.jl)
-
-User Guides
------------
-This section contains users guides that are general across languages.
-* [Installation Guide](build.md)
-* [Introduction to Boosted Trees](model.md)
-* [Distributed Training Tutorial](tutorial/aws_yarn.md)
-* [Frequently Asked Questions](faq.md)
-* [External Memory Version](external_memory.md)
-* [Learning to use XGBoost by Example](../demo)
-* [Parameters](parameter.md)
-* [Text input format](input_format.md)
-* [Notes on Parameter Tunning](param_tuning.md)
-
-
-Tutorials
----------
-This section contains official tutorials of XGBoost package.
-See [Awesome XGBoost](https://github.com/dmlc/xgboost/tree/master/demo) for links to mores resources.
-* [Introduction to XGBoost in R](R-package/xgboostPresentation.md) (R package)
- - This is a general presentation about xgboost in R.
-* [Discover your data with XGBoost in R](R-package/discoverYourData.md) (R package)
- - This tutorial explaining feature analysis in xgboost.
-* [Introduction of XGBoost in Python](python/python_intro.md) (python)
- - This tutorial introduces the python package of xgboost
-* [Understanding XGBoost Model on Otto Dataset](../demo/kaggle-otto/understandingXGBoostModel.Rmd) (R package)
- - This tutorial teaches you how to use xgboost to compete kaggle otto challenge.
-
-Developer Guide
----------------
-* [Contributor Guide](dev-guide/contribute.md)
-
-
-Indices and tables
-------------------
-
-```eval_rst
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
-```
+* [Julia Package Document](julia/index.md)
+* [CLI Package Document](cli/index.md)
+* [How To Documents](how_to/index.md)
+* [Get Started Documents](get_started/index.md)
+* [Tutorials](tutorials/index.md)
diff --git a/doc/julia/index.md b/doc/julia/index.md
new file mode 100644
index 000000000..470e09644
--- /dev/null
+++ b/doc/julia/index.md
@@ -0,0 +1,3 @@
+# XGBoost.jl
+
+See [XGBoost.jl Project page](https://github.com/dmlc/XGBoost.jl)
\ No newline at end of file
diff --git a/doc/jvm/index.md b/doc/jvm/index.md
index e9a16477e..e3ff666c0 100644
--- a/doc/jvm/index.md
+++ b/doc/jvm/index.md
You have found the XGBoost JVM Package!
Installation
------------
-Currently, XGBoost4J only support installation from source. Building XGBoost4J using Maven requires Maven 3 or newer and Java 7+.
+Currently, XGBoost4J only supports installation from source. Building XGBoost4J using Maven requires Maven 3 or newer and Java 7+.
Before you install XGBoost4J, you need to define environment variable `JAVA_HOME` as your JDK directory to ensure that your compiler can find `jni.h` correctly, since XGBoost4J relies on JNI to implement the interaction between the JVM and native libraries.
-After your `JAVA_HOME` is defined correctly, it is as simple as run `mvn package` under jvm-packages directory to install XGBoost4J.
+After your `JAVA_HOME` is defined correctly, installing XGBoost4J is as simple as running `mvn package` under the jvm-packages directory. You can also skip the tests by running `mvn -DskipTests=true package`, if you are sure about the correctness of your local setup.
-NOTE: XGBoost4J requires to run with Spark 1.6 or newer
+XGBoost4J-Spark, which integrates XGBoost with Spark, requires Spark 1.6 or newer (the default version is 1.6.1). You can build XGBoost4J-Spark as a component of XGBoost4J by running `mvn package`, or specify the Spark version with `mvn -Dspark.version=1.6.0 package`.
Contents
--------
* [Java Overview Tutorial](java_intro.md)
+
+Resources
+---------
* [Code Examples](https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example)
* [Java API Docs](http://dmlc.ml/docs/javadocs/index.html)
-* [Scala API Docs]
+
+Scala API Docs
+--------------
* [XGBoost4J](http://dmlc.ml/docs/scaladocs/xgboost4j/index.html)
* [XGBoost4J-Spark](http://dmlc.ml/docs/scaladocs/xgboost4j-spark/index.html)
- * [XGBoost4J-Flink](http://dmlc.ml/docs/scaladocs/xgboost4j-flink/index.html)
\ No newline at end of file
+ * [XGBoost4J-Flink](http://dmlc.ml/docs/scaladocs/xgboost4j-flink/index.html)
diff --git a/doc/parameter.md b/doc/parameter.md
index 3ca79077b..70575343e 100644
--- a/doc/parameter.md
+++ b/doc/parameter.md
@@ -13,7 +13,8 @@ In R-package, you can use .(dot) to replace under score in the parameters, for e
General Parameters
------------------
* booster [default=gbtree]
- - which booster to use, can be gbtree or gblinear. gbtree uses tree based model while gblinear uses linear function.
+  - which booster to use, can be gbtree, gblinear or dart.
+    gbtree and dart use tree based models while gblinear uses a linear function.
* silent [default=0]
- 0 means printing running messages, 1 means silent mode.
* nthread [default to maximum number of threads available if not set]
@@ -72,7 +73,29 @@ Parameters for Tree Booster
but consider set to lower number for more accurate enumeration.
- range: (0, 1)
* scale_pos_weight, [default=0]
- - Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases) See [Parameters Tuning](param_tuning.md) for more discussion. Also see Higgs Kaggle competition demo for examples: [R](../demo/kaggle-higgs/higgs-train.R ), [py1](../demo/kaggle-higgs/higgs-numpy.py ), [py2](../demo/kaggle-higgs/higgs-cv.py ), [py3](../demo/guide-python/cross_validation.py)
+ - Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases) See [Parameters Tuning](how_to/param_tuning.md) for more discussion. Also see Higgs Kaggle competition demo for examples: [R](../demo/kaggle-higgs/higgs-train.R ), [py1](../demo/kaggle-higgs/higgs-numpy.py ), [py2](../demo/kaggle-higgs/higgs-cv.py ), [py3](../demo/guide-python/cross_validation.py)
+
+Additional parameters for Dart Booster
+--------------------------------------
+* sample_type [default="uniform"]
+  - type of sampling algorithm.
+    - "uniform": dropped trees are selected uniformly.
+    - "weighted": dropped trees are selected in proportion to weight.
+* normalize_type [default="tree"]
+  - type of normalization algorithm.
+    - "tree": new trees have the same weight as each of the dropped trees.
+      weight of new trees is learning_rate / (k + learning_rate);
+      dropped trees are scaled by a factor of k / (k + learning_rate)
+    - "forest": new trees have the same weight as the sum of the dropped trees (forest).
+      weight of new trees is learning_rate / (1 + learning_rate);
+      dropped trees are scaled by a factor of 1 / (1 + learning_rate)
+* rate_drop [default=0.0]
+  - dropout rate.
+  - range: [0.0, 1.0]
+* skip_drop [default=0.0]
+  - probability of skipping dropout.
+    If a dropout is skipped, new trees are added in the same manner as gbtree.
+  - range: [0.0, 1.0]
Parameters for Linear Booster
-----------------------------
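
A minimal sketch of training with the DART booster parameters documented above; the demo data path is an assumption:

```python
import xgboost as xgb

dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')  # assumed demo path
param = {'booster': 'dart', 'max_depth': 2, 'eta': 1,
         'objective': 'binary:logistic',
         'sample_type': 'uniform', 'normalize_type': 'tree',
         'rate_drop': 0.1, 'skip_drop': 0.5}
bst = xgb.train(param, dtrain, num_boost_round=10)
# dropout is also applied at prediction time unless ntree_limit pins the
# number of trees to use (assumed from the DART design)
preds = bst.predict(dtrain, ntree_limit=10)
```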
diff --git a/doc/tutorial/aws_yarn.md b/doc/tutorials/aws_yarn.md
similarity index 100%
rename from doc/tutorial/aws_yarn.md
rename to doc/tutorials/aws_yarn.md
diff --git a/doc/tutorials/index.md b/doc/tutorials/index.md
new file mode 100644
index 000000000..a4edf51c1
--- /dev/null
+++ b/doc/tutorials/index.md
@@ -0,0 +1,8 @@
+# XGBoost Tutorials
+
+This section contains official tutorials of the XGBoost package.
+See [Awesome XGBoost](https://github.com/dmlc/xgboost/tree/master/demo) for links to more resources.
+
+## Contents
+- [Introduction to Boosted Trees](../model.md)
+- [Distributed XGBoost YARN on AWS](aws_yarn.md)
diff --git a/jvm-packages/README.md b/jvm-packages/README.md
index e1dfb1576..a9aded6f8 100644
--- a/jvm-packages/README.md
+++ b/jvm-packages/README.md
@@ -61,7 +61,6 @@ object DistTrainWithSpark {
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)
- val sc = new SparkContext(sparkConf)
val inputTrainPath = args(1)
val outputModelPath = args(2)
// number of iterations
@@ -73,7 +72,8 @@ object DistTrainWithSpark {
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// use 5 distributed workers to train the model
- val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5)
+    // useExternalMemory indicates whether to use external memory cache to reduce the RAM cost
+ val model = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = 5, useExternalMemory = true)
// save model to HDFS path
model.saveModelToHadoop(outputModelPath)
}
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 5d0cbd00b..c9566016b 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -23,6 +23,18 @@
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
+  <profiles>
+    <profile>
+      <id>spark-1.x</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+      </activation>
+      <properties>
+        <spark.version>1.6.1</spark.version>
+        <scala.binary.version>2.10</scala.binary.version>
+      </properties>
+    </profile>
+  </profiles>
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java
index 349098ae1..a4a1cb703 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java
@@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
+ DMatrix trainMat = new DMatrix("../demo/data/agaricus.txt.train#dtrain.cache");
+ DMatrix testMat = new DMatrix("../demo/data/agaricus.txt.test#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
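
The `#cache_prefix` convention shown in this Java example works the same way from the Python package; a minimal sketch with paths assumed relative to the repo root:

```python
import xgboost as xgb

# appending '#' plus a prefix switches the DMatrix to external-memory mode;
# cache files with that prefix are written next to the process
dtrain = xgb.DMatrix('demo/data/agaricus.txt.train#dtrain.cache')
dtest = xgb.DMatrix('demo/data/agaricus.txt.test#dtest.cache')
bst = xgb.train({'objective': 'binary:logistic'}, dtrain,
                num_boost_round=2, evals=[(dtest, 'test')])
```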
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
index 978e8f0ee..a5ebfa05a 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/DistTrainWithSpark.scala
@@ -28,7 +28,7 @@ object DistTrainWithSpark {
"usage: program num_of_rounds num_workers training_path test_path model_path")
sys.exit(1)
}
- val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoost-spark-example")
+ val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf)
@@ -45,7 +45,8 @@ object DistTrainWithSpark {
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
- val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt)
+ val xgboostModel = XGBoost.train(trainRDD, paramMap, numRound, nWorkers = args(1).toInt,
+ useExternalMemory = true)
xgboostModel.predict(new DMatrix(testSet))
// save model to HDFS path
xgboostModel.saveModelAsHadoopFile(outputModelPath)
diff --git a/jvm-packages/xgboost4j-spark/pom.xml b/jvm-packages/xgboost4j-spark/pom.xml
index ac37e78ac..120616c58 100644
--- a/jvm-packages/xgboost4j-spark/pom.xml
+++ b/jvm-packages/xgboost4j-spark/pom.xml
@@ -28,8 +28,8 @@
<groupId>org.apache.spark</groupId>
-      <artifactId>spark-mllib_2.10</artifactId>
-      <version>1.6.1</version>
+      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
\ No newline at end of file
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index dc1b5382c..5903bd2c9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -16,6 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
+import java.nio.file.Paths
+
import scala.collection.mutable
import scala.collection.JavaConverters._
@@ -41,7 +43,8 @@ object XGBoost extends Serializable {
trainingData: RDD[LabeledPoint],
xgBoostConfMap: Map[String, Any],
rabitEnv: mutable.Map[String, String],
- numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait): RDD[Booster] = {
+ numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
+ useExternalMemory: Boolean): RDD[Booster] = {
import DataUtils._
val partitionedData = {
if (numWorkers > trainingData.partitions.length) {
@@ -54,11 +57,19 @@ object XGBoost extends Serializable {
trainingData
}
}
+ val appName = partitionedData.context.appName
partitionedData.mapPartitions {
trainingSamples =>
rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
- val trainingSet = new DMatrix(new JDMatrix(trainingSamples, null))
+ val cacheFileName: String = {
+ if (useExternalMemory && trainingSamples.hasNext) {
+ s"$appName-dtrain_cache-${TaskContext.getPartitionId()}"
+ } else {
+ null
+ }
+ }
+ val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
val booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
watches = new mutable.HashMap[String, DMatrix]{put("train", trainingSet)}.toMap,
obj, eval)
@@ -76,12 +87,15 @@ object XGBoost extends Serializable {
* workers equals to the partition number of trainingData RDD
* @param obj the user-defined objective function, null by default
* @param eval the user-defined evaluation function, null by default
+   * @param useExternalMemory indicates whether to use external memory cache; by setting this flag
+   *        to true, the user may reduce the RAM cost of running XGBoost within Spark
* @throws ml.dmlc.xgboost4j.java.XGBoostError when the model training is failed
* @return XGBoostModel when successful training
*/
@throws(classOf[XGBoostError])
def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
- nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null): XGBoostModel = {
+ nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
+ useExternalMemory: Boolean = false): XGBoostModel = {
require(nWorkers > 0, "you must specify more than 0 workers")
val tracker = new RabitTracker(nWorkers)
implicit val sc = trainingData.sparkContext
@@ -97,7 +111,7 @@ object XGBoost extends Serializable {
}
require(tracker.start(), "FAULT: Failed to start tracker")
val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
- tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval)
+ tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
val sparkJobThread = new Thread() {
override def run() {
// force the job
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index 75a91e64c..f81e63048 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.hadoop.fs.{Path, FileSystem}
-import org.apache.spark.SparkContext
+import org.apache.spark.{TaskContext, SparkContext}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
@@ -27,13 +27,23 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser
/**
* Predict result with the given testset (represented as RDD)
+   * @param testSet test set represented as RDD
+ * @param useExternalCache whether to use external cache for the test set
*/
- def predict(testSet: RDD[Vector]): RDD[Array[Array[Float]]] = {
+ def predict(testSet: RDD[Vector], useExternalCache: Boolean = false): RDD[Array[Array[Float]]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
+ val appName = testSet.context.appName
testSet.mapPartitions { testSamples =>
if (testSamples.hasNext) {
- val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
+ val cacheFileName = {
+ if (useExternalCache) {
+ s"$appName-dtest_cache-${TaskContext.getPartitionId()}"
+ } else {
+ null
+ }
+ }
+ val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
Iterator(broadcastBooster.value.predict(dMatrix))
} else {
Iterator()
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 711ea35f0..71bb9ecf8 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -127,7 +127,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
- numWorkers = 2, round = 5, null, null)
+ numWorkers = 2, round = 5, null, null, false)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
val boosters = boosterRDD.collect()
@@ -210,4 +210,26 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
println(xgBoostModel.predict(testRDD))
}
+
+ test("training with external memory cache") {
+ sc.stop()
+ sc = null
+ val sparkConf = new SparkConf().setMaster("local[*]").setAppName("XGBoostSuite")
+ val customSparkContext = new SparkContext(sparkConf)
+ val eval = new EvalError()
+ val trainingRDD = buildTrainingRDD(Some(customSparkContext))
+ val testSet = readFile(getClass.getResource("/agaricus.txt.test").getFile).iterator
+ import DataUtils._
+ val testSetDMatrix = new DMatrix(new JDMatrix(testSet, null))
+ val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
+ "objective" -> "binary:logistic").toMap
+ val xgBoostModel = XGBoost.train(trainingRDD, paramMap, 5, numWorkers, useExternalMemory = true)
+ assert(eval.eval(xgBoostModel.predict(testSetDMatrix), testSetDMatrix) < 0.1)
+ customSparkContext.stop()
+ // clean
+ val dir = new File(".")
+ for (file <- dir.listFiles() if file.getName.startsWith("XGBoostSuite-dtrain_cache")) {
+ file.delete()
+ }
+ }
}
diff --git a/python-package/.pylintrc b/python-package/.pylintrc
index 1e63cdabe..e8e957d2b 100644
--- a/python-package/.pylintrc
+++ b/python-package/.pylintrc
@@ -2,8 +2,8 @@
ignore=tests
-unexpected-special-method-signature,too-many-nested-blocks
+disable=unexpected-special-method-signature,too-many-nested-blocks
dummy-variables-rgx=(unused|)_.*
-reports=no
\ No newline at end of file
+reports=no
diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
new file mode 100644
index 000000000..3683ea2dd
--- /dev/null
+++ b/python-package/xgboost/callback.py
@@ -0,0 +1,217 @@
+# coding: utf-8
+# pylint: disable= invalid-name
+"""Training Library containing training routines."""
+from __future__ import absolute_import
+
+from . import rabit
+from .core import EarlyStopException
+
+
+def _fmt_metric(value, show_stdv=True):
+ """format metric string"""
+ if len(value) == 2:
+ return '%s:%g' % (value[0], value[1])
+ elif len(value) == 3:
+ if show_stdv:
+ return '%s:%g+%g' % (value[0], value[1], value[2])
+ else:
+ return '%s:%g' % (value[0], value[1])
+ else:
+ raise ValueError("wrong metric value")
+
+
+def print_evaluation(period=1, show_stdv=True):
+    """Create a callback that prints the evaluation results.
+
+ Parameters
+ ----------
+ period : int
+ The period to log the evaluation results
+
+ show_stdv : bool, optional
+        Whether to show stdv (standard deviation) if provided
+
+ Returns
+ -------
+ callback : function
+        A callback that prints the evaluation results every **period** iterations.
+ """
+ def callback(env):
+ """internal function"""
+ if env.rank != 0 or len(env.evaluation_result_list) == 0:
+ return
+ i = env.iteration
+ if (i % period == 0 or i + 1 == env.begin_iteration):
+ msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
+ rabit.tracker_print('[%d]\t%s\n' % (i, msg))
+ return callback
+
+
+def record_evaluation(eval_result):
+    """Create a callback that records the evaluation history into eval_result.
+
+ Parameters
+ ----------
+ eval_result : dict
+ A dictionary to store the evaluation results.
+
+ Returns
+ -------
+ callback : function
+ The requested callback function.
+ """
+ if not isinstance(eval_result, dict):
+ raise TypeError('eval_result has to be a dictionary')
+ eval_result.clear()
+
+ def init(env):
+ """internal function"""
+ for k, _ in env.evaluation_result_list:
+ key, metric = k.split('-')
+ if key not in eval_result:
+ eval_result[key] = {}
+ if metric not in eval_result[key]:
+ eval_result[key][metric] = []
+
+ def callback(env):
+ """internal function"""
+ if len(eval_result) == 0:
+ init(env)
+ for k, v in env.evaluation_result_list:
+ key, metric = k.split('-')
+ eval_result[key][metric].append(v)
+ return callback
+
+
+def reset_learning_rate(learning_rates):
+    """Reset the learning rate at each boosting round.
+
+    NOTE: the initial learning rate will still take effect on the first iteration.
+
+ Parameters
+ ----------
+ learning_rates: list or function
+ List of learning rate for each boosting round
+ or a customized function that calculates eta in terms of
+ current number of round and the total number of boosting round (e.g. yields
+ learning rate decay)
+ - list l: eta = l[boosting round]
+ - function f: eta = f(boosting round, num_boost_round)
+
+ Returns
+ -------
+ callback : function
+ The requested callback function.
+ """
+ def callback(env):
+ """internal function"""
+ bst = env.model
+ i = env.iteration
+ if isinstance(learning_rates, list):
+ if len(learning_rates) != env.end_iteration:
+ raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
+ bst.set_param('learning_rate', learning_rates[i])
+ else:
+ bst.set_param('learning_rate', learning_rates(i, env.end_iteration))
+ callback.before_iteration = True
+ return callback
+
+
+def early_stop(stopping_rounds, maximize=False, verbose=True):
+    """Create a callback that activates early stopping.
+
+ Validation error needs to decrease at least
+    every **stopping_rounds** round(s) to continue training.
+ Requires at least one item in evals.
+ If there's more than one, will use the last.
+ Returns the model from the last iteration (not the best one).
+ If early stopping occurs, the model will have three additional fields:
+ bst.best_score, bst.best_iteration and bst.best_ntree_limit.
+ (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
+ and/or num_class appears in the parameters)
+
+ Parameters
+ ----------
+    stopping_rounds : int
+        Stop training if the evaluation metric has not improved in this many consecutive rounds.
+
+ maximize : bool
+ Whether to maximize evaluation metric.
+
+ verbose : optional, bool
+ Whether to print message about early stopping information.
+
+ Returns
+ -------
+ callback : function
+ The requested callback function.
+ """
+ state = {}
+
+ def init(env):
+ """internal function"""
+ bst = env.model
+
+ if len(env.evaluation_result_list) == 0:
+ raise ValueError('For early stopping you need at least one set in evals.')
+ if len(env.evaluation_result_list) > 1 and verbose:
+ msg = ("Multiple eval metrics have been passed: "
+ "'{0}' will be used for early stopping.\n\n")
+ rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
+ maximize_metrics = ('auc', 'map', 'ndcg')
+ maximize_score = maximize
+ metric = env.evaluation_result_list[-1][0]
+ if any(env.evaluation_result_list[-1][0].split('-')[1].startswith(x)
+ for x in maximize_metrics):
+ maximize_score = True
+
+ if verbose and env.rank == 0:
+ msg = "Will train until {} hasn't improved in {} rounds.\n"
+ rabit.tracker_print(msg.format(metric, stopping_rounds))
+
+ state['maximize_score'] = maximize_score
+ state['best_iteration'] = 0
+ if maximize_score:
+ state['best_score'] = float('-inf')
+ else:
+ state['best_score'] = float('inf')
+
+ if bst is not None:
+ if bst.attr('best_score') is not None:
+ state['best_score'] = float(bst.attr('best_score'))
+ state['best_iteration'] = int(bst.attr('best_iteration'))
+ state['best_msg'] = bst.attr('best_msg')
+ else:
+ bst.set_attr(best_iteration=str(state['best_iteration']))
+ bst.set_attr(best_score=str(state['best_score']))
+ else:
+ assert env.cvfolds is not None
+
+ def callback(env):
+ """internal function"""
+ score = env.evaluation_result_list[-1][1]
+ if len(state) == 0:
+ init(env)
+ best_score = state['best_score']
+ best_iteration = state['best_iteration']
+ maximize_score = state['maximize_score']
+ if (maximize_score and score > best_score) or \
+ (not maximize_score and score < best_score):
+ msg = '[%d]\t%s' % (
+ env.iteration,
+ '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
+ state['best_msg'] = msg
+ state['best_score'] = score
+ state['best_iteration'] = env.iteration
+ # save the property to attributes, so they will occur in checkpoint.
+ if env.model is not None:
+ env.model.set_attr(best_score=str(state['best_score']),
+ best_iteration=str(state['best_iteration']),
+ best_msg=state['best_msg'])
+ elif env.iteration - best_iteration >= stopping_rounds:
+ best_msg = state['best_msg']
+ if verbose and env.rank == 0:
+ msg = "Stopping. Best iteration:\n{}\n\n"
+ rabit.tracker_print(msg.format(best_msg))
+ raise EarlyStopException(best_iteration)
+ return callback
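
Taken together, the new callbacks plug into training through the `callbacks` argument added to `xgb.train`/`xgb.cv` (see training.py below). A minimal sketch combining them; the demo data paths are assumptions:

```python
import xgboost as xgb

dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('demo/data/agaricus.txt.test')
param = {'max_depth': 2, 'objective': 'binary:logistic'}

history = {}
bst = xgb.train(param, dtrain, num_boost_round=20, evals=[(dtest, 'eval')],
                verbose_eval=False,  # printing handled by the explicit callback
                callbacks=[
                    xgb.callback.print_evaluation(period=5),
                    xgb.callback.record_evaluation(history),
                    # function form: eta = f(current_round, num_boost_round)
                    xgb.callback.reset_learning_rate(lambda i, n: 0.3 * (0.99 ** i)),
                    xgb.callback.early_stop(stopping_rounds=5),
                ])
print(history['eval']['error'])  # per-round metric values recorded above
```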
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 44707c539..8237b1249 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -1,5 +1,5 @@
# coding: utf-8
-# pylint: disable=unused-import, invalid-name, wrong-import-position
+# pylint: disable= invalid-name, unused-import
"""For compatibility"""
from __future__ import absolute_import
@@ -14,12 +14,14 @@ if PY3:
STRING_TYPES = str,
def py_str(x):
+ """convert c string back to python string"""
return x.decode('utf-8')
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,
def py_str(x):
+ """convert c string back to python string"""
return x
try:
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index f22ca7ef1..e31f622cf 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1,5 +1,6 @@
# coding: utf-8
-# pylint: disable=too-many-arguments, too-many-branches
+# pylint: disable=too-many-arguments, too-many-branches, invalid-name
+# pylint: disable=too-many-branches, too-many-lines, W0141
"""Core XGBoost Library."""
from __future__ import absolute_import
@@ -22,6 +23,31 @@ class XGBoostError(Exception):
pass
+class EarlyStopException(Exception):
+ """Exception to signal early stopping.
+
+ Parameters
+ ----------
+ best_iteration : int
+ The best iteration stopped.
+ """
+ def __init__(self, best_iteration):
+ super(EarlyStopException, self).__init__()
+ self.best_iteration = best_iteration
+
+
+# Callback environment used by callbacks
+CallbackEnv = collections.namedtuple(
+ "XGBoostCallbackEnv",
+ ["model",
+ "cvfolds",
+ "iteration",
+ "begin_iteration",
+ "end_iteration",
+ "rank",
+ "evaluation_result_list"])
+
+
def from_pystr_to_cstr(data):
"""Convert a list of Python str to C pointer
@@ -657,7 +683,7 @@ class Booster(object):
def __copy__(self):
return self.__deepcopy__(None)
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, _):
return Booster(model_file=self.save_raw())
def copy(self):
@@ -975,7 +1001,6 @@ class Booster(object):
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
def dump_model(self, fout, fmap='', with_stats=False):
- # pylint: disable=consider-using-enumerate
"""
Dump model into a text file.
@@ -1143,10 +1168,12 @@ class Booster(object):
msg = 'feature_names mismatch: {0} {1}'
if dat_missing:
- msg += '\nexpected ' + ', '.join(str(s) for s in dat_missing) + ' in input data'
+ msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
+ ' in input data')
if my_missing:
- msg += '\ntraining data did not have the following fields: ' + ', '.join(str(s) for s in my_missing)
+ msg += ('\ntraining data did not have the following fields: ' +
+ ', '.join(str(s) for s in my_missing))
raise ValueError(msg.format(self.feature_names,
data.feature_names))
@@ -1161,23 +1188,25 @@ class Booster(object):
The name of feature map file.
bin: int, default None
The maximum number of bins.
- Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
+ Number of bins equals number of unique split values n_unique,
+ if bins == None or bins > n_unique.
as_pandas : bool, default True
Return pd.DataFrame when pandas is installed.
If False or pandas is not installed, return numpy ndarray.
Returns
-------
- a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
+ a histogram of used splitting values for the specified feature
+ either as numpy array or pandas DataFrame.
"""
xgdump = self.get_dump(fmap=fmap)
values = []
- regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
+ regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
for i in range(len(xgdump)):
m = re.findall(regexp, xgdump[i])
values.extend(map(float, m))
- n_unique = np.unique(values).shape[0]
+ n_unique = len(np.unique(values))
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
nph = np.histogram(values, bins=bins)
@@ -1187,7 +1216,8 @@ class Booster(object):
if as_pandas and PANDAS_INSTALLED:
return DataFrame(nph, columns=['SplitValue', 'Count'])
elif as_pandas and not PANDAS_INSTALLED:
- sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
+ sys.stderr.write(
+ "Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
return nph
else:
return nph
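
The `CallbackEnv` namedtuple above is the object handed to every callback; a minimal custom-callback sketch using its fields:

```python
def log_eval(env):
    """Print each evaluation result of the current round (illustrative helper)."""
    for res in env.evaluation_result_list or []:
        # res is (name, value) under train(), (name, mean, std) under cv()
        print('[%d] %s: %g' % (env.iteration, res[0], res[1]))

# used as: xgb.train(param, dtrain, evals=[(dtest, 'eval')], callbacks=[log_eval])
```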
diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py
index af85b2dd0..89b2a4ec6 100644
--- a/python-package/xgboost/rabit.py
+++ b/python-package/xgboost/rabit.py
@@ -1,3 +1,6 @@
+# coding: utf-8
+# pylint: disable= invalid-name
+
"""Distributed XGBoost Rabit related API."""
from __future__ import absolute_import
import sys
@@ -179,7 +182,7 @@ def allreduce(data, op, prepare_fun=None):
else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
- def pfunc(args):
+ def pfunc(_):
"""prepare function."""
prepare_fun(data)
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index cafbe073f..2b4c2accb 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1,5 +1,5 @@
# coding: utf-8
-# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
+# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import
@@ -42,6 +42,7 @@ def _objective_decorator(func):
``dmatrix.get_label()``
"""
def inner(preds, dmatrix):
+ """internal function"""
labels = dmatrix.get_label()
return func(labels, preds)
return inner
@@ -79,9 +80,9 @@ class XGBModel(XGBModelBase):
colsample_bylevel : float
Subsample ratio of columns for each split, in each level.
reg_alpha : float (xgb's alpha)
- L2 regularization term on weights
- reg_lambda : float (xgb's lambda)
L1 regularization term on weights
+ reg_lambda : float (xgb's lambda)
+ L2 regularization term on weights
scale_pos_weight : float
Balancing of positive and negative weights.
@@ -183,7 +184,7 @@ class XGBModel(XGBModelBase):
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
- # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init, redefined-variable-type
+ # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
@@ -351,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
- # pylint: disable = attribute-defined-outside-init,arguments-differ, redefined-variable-type
+ # pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
@@ -440,6 +441,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose)
+ self.objective = xgb_options["objective"]
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
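
With the change above, the fitted estimator keeps the resolved objective. A minimal sketch of the sklearn wrapper with early stopping; the synthetic data is illustrative only:

```python
import numpy as np
import xgboost as xgb

X = np.random.rand(200, 5)
y = np.random.randint(2, size=200)

clf = xgb.XGBClassifier(n_estimators=20, max_depth=2)
clf.fit(X, y, eval_set=[(X, y)], eval_metric='error',
        early_stopping_rounds=3, verbose=False)
print(clf.objective)  # populated by fit, per the change above
```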
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index d21edd30d..3da92ff51 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -1,20 +1,122 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
-# pylint: disable=too-many-branches
+# pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines."""
from __future__ import absolute_import
-import sys
-import re
+
import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import rabit
+from . import callback
+
+
+def _train_internal(params, dtrain,
+ num_boost_round=10, evals=(),
+ obj=None, feval=None,
+ xgb_model=None, callbacks=None):
+ """internal training function"""
+ callbacks = [] if callbacks is None else callbacks
+ evals = list(evals)
+ if isinstance(params, dict) \
+ and 'eval_metric' in params \
+ and isinstance(params['eval_metric'], list):
+ params = dict((k, v) for k, v in params.items())
+ eval_metrics = params['eval_metric']
+ params.pop("eval_metric", None)
+ params = list(params.items())
+ for eval_metric in eval_metrics:
+ params += [('eval_metric', eval_metric)]
+
+ bst = Booster(params, [dtrain] + [d[0] for d in evals])
+ nboost = 0
+ num_parallel_tree = 1
+
+ if xgb_model is not None:
+ if not isinstance(xgb_model, STRING_TYPES):
+ xgb_model = xgb_model.save_raw()
+ bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
+ nboost = len(bst.get_dump())
+ else:
+ bst = Booster(params, [dtrain] + [d[0] for d in evals])
+
+ _params = dict(params) if isinstance(params, list) else params
+
+ if 'num_parallel_tree' in _params:
+ num_parallel_tree = _params['num_parallel_tree']
+ nboost //= num_parallel_tree
+ if 'num_class' in _params:
+ nboost //= _params['num_class']
+
+ # Distributed code: Load the checkpoint from rabit.
+ version = bst.load_rabit_checkpoint()
+ assert(rabit.get_world_size() != 1 or version == 0)
+ rank = rabit.get_rank()
+ start_iteration = int(version / 2)
+ nboost += start_iteration
+
+ callbacks_before_iter = [
+ cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
+ callbacks_after_iter = [
+ cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
+
+ for i in range(start_iteration, num_boost_round):
+ for cb in callbacks_before_iter:
+ cb(CallbackEnv(model=bst,
+ cvfolds=None,
+ iteration=i,
+ begin_iteration=start_iteration,
+ end_iteration=num_boost_round,
+ rank=rank,
+ evaluation_result_list=None))
+ # Distributed code: need to resume to this point.
+ # Skip the first update if it is a recovery step.
+ if version % 2 == 0:
+ bst.update(dtrain, i, obj)
+ bst.save_rabit_checkpoint()
+ version += 1
+
+ assert(rabit.get_world_size() == 1 or version == rabit.version_number())
+
+ nboost += 1
+ evaluation_result_list = []
+ # check evaluation result.
+ if len(evals) != 0:
+ bst_eval_set = bst.eval_set(evals, i, feval)
+ if isinstance(bst_eval_set, STRING_TYPES):
+ msg = bst_eval_set
+ else:
+ msg = bst_eval_set.decode()
+ res = [x.split(':') for x in msg.split()]
+ evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
+ try:
+ for cb in callbacks_after_iter:
+ cb(CallbackEnv(model=bst,
+ cvfolds=None,
+ iteration=i,
+ begin_iteration=start_iteration,
+ end_iteration=num_boost_round,
+ rank=rank,
+ evaluation_result_list=evaluation_result_list))
+ except EarlyStopException:
+ break
+ # do checkpoint after evaluation, in case evaluation also updates booster.
+ bst.save_rabit_checkpoint()
+ version += 1
+
+ if bst.attr('best_score') is not None:
+ bst.best_score = float(bst.attr('best_score'))
+ bst.best_iteration = int(bst.attr('best_iteration'))
+ else:
+ bst.best_iteration = nboost - 1
+ bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+ return bst
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize=False, early_stopping_rounds=None, evals_result=None,
- verbose_eval=True, learning_rates=None, xgb_model=None):
+ verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters.
@@ -70,176 +172,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
xgb_model : file name of stored xgb model or 'Booster' instance
Xgb model to be loaded before training (allows training continuation).
+ callbacks : list of callback functions
+ List of callback functions that are applied at the end of each iteration.
+
Returns
-------
booster : a trained booster model
"""
- evals = list(evals)
- if isinstance(params, dict) \
- and 'eval_metric' in params \
- and isinstance(params['eval_metric'], list):
- params = dict((k, v) for k, v in params.items())
- eval_metrics = params['eval_metric']
- params.pop("eval_metric", None)
- params = list(params.items())
- for eval_metric in eval_metrics:
- params += [('eval_metric', eval_metric)]
+ callbacks = [] if callbacks is None else callbacks
- bst = Booster(params, [dtrain] + [d[0] for d in evals])
- nboost = 0
- num_parallel_tree = 1
-
- if isinstance(verbose_eval, bool):
- verbose_eval_every_line = False
+ # Most of the legacy advanced options become callbacks
+ if isinstance(verbose_eval, bool) and verbose_eval:
+ callbacks.append(callback.print_evaluation())
else:
if isinstance(verbose_eval, int):
- verbose_eval_every_line = verbose_eval
- verbose_eval = True if verbose_eval_every_line > 0 else False
+ callbacks.append(callback.print_evaluation(verbose_eval))
- if rabit.get_rank() != 0:
- verbose_eval = False
-
- if xgb_model is not None:
- if not isinstance(xgb_model, STRING_TYPES):
- xgb_model = xgb_model.save_raw()
- bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
- nboost = len(bst.get_dump())
- else:
- bst = Booster(params, [dtrain] + [d[0] for d in evals])
-
- _params = dict(params) if isinstance(params, list) else params
- _eta_param_name = 'eta' if 'eta' in _params else 'learning_rate'
- if 'num_parallel_tree' in _params:
- num_parallel_tree = _params['num_parallel_tree']
- nboost //= num_parallel_tree
- if 'num_class' in _params:
- nboost //= _params['num_class']
+ if early_stopping_rounds is not None:
+ callbacks.append(callback.early_stop(early_stopping_rounds,
+ maximize=maximize,
+ verbose=bool(verbose_eval)))
+ if learning_rates is not None:
+ callbacks.append(callback.reset_learning_rate(learning_rates))
if evals_result is not None:
- if not isinstance(evals_result, dict):
- raise TypeError('evals_result has to be a dictionary')
- else:
- evals_name = [d[1] for d in evals]
- evals_result.clear()
- evals_result.update(dict([(key, {}) for key in evals_name]))
+ callbacks.append(callback.record_evaluation(evals_result))
- # early stopping
- if early_stopping_rounds is not None:
- if len(evals) < 1:
- raise ValueError('For early stopping you need at least one set in evals.')
-
- if verbose_eval:
- rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
- evals[-1][1], early_stopping_rounds))
-
- # is params a list of tuples? are we using multiple eval metrics?
- if isinstance(params, list):
- if len(params) != len(dict(params).items()):
- params = dict(params)
- msg = ("Multiple eval metrics have been passed: "
- "'{0}' will be used for early stopping.\n\n")
- rabit.tracker_print(msg.format(params['eval_metric']))
- else:
- params = dict(params)
-
- # either minimize loss or maximize AUC/MAP/NDCG
- maximize_score = False
- if 'eval_metric' in params:
- maximize_metrics = ('auc', 'map', 'ndcg')
- if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
- maximize_score = True
- if feval is not None:
- maximize_score = maximize
-
- if maximize_score:
- bst.set_attr(best_score='0.0')
- else:
- bst.set_attr(best_score='inf')
- bst.set_attr(best_iteration='0')
-
- if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
- raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
-
- # Distributed code: Load the checkpoint from rabit.
- version = bst.load_rabit_checkpoint()
- assert(rabit.get_world_size() != 1 or version == 0)
- start_iteration = int(version / 2)
- nboost += start_iteration
-
- for i in range(start_iteration, num_boost_round):
- if learning_rates is not None:
- if isinstance(learning_rates, list):
- bst.set_param(_eta_param_name, learning_rates[i])
- else:
- bst.set_param(_eta_param_name, learning_rates(i, num_boost_round))
-
- # Distributed code: need to resume to this point.
- # Skip the first update if it is a recovery step.
- if version % 2 == 0:
- bst.update(dtrain, i, obj)
- bst.save_rabit_checkpoint()
- version += 1
-
- assert(rabit.get_world_size() == 1 or version == rabit.version_number())
-
- nboost += 1
- # check evaluation result.
- if len(evals) != 0:
- bst_eval_set = bst.eval_set(evals, i, feval)
-
- if isinstance(bst_eval_set, STRING_TYPES):
- msg = bst_eval_set
- else:
- msg = bst_eval_set.decode()
-
- if verbose_eval:
- if verbose_eval_every_line:
- if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
- rabit.tracker_print(msg + '\n')
- else:
- rabit.tracker_print(msg + '\n')
-
- if evals_result is not None:
- res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
- for key in evals_name:
- evals_idx = evals_name.index(key)
- res_per_eval = len(res) // len(evals_name)
- for r in range(res_per_eval):
- res_item = res[(evals_idx * res_per_eval) + r]
- res_key = res_item[0]
- res_val = res_item[1]
- if res_key in evals_result[key]:
- evals_result[key][res_key].append(res_val)
- else:
- evals_result[key][res_key] = [res_val]
-
- if early_stopping_rounds:
- score = float(msg.rsplit(':', 1)[1])
- best_score = float(bst.attr('best_score'))
- best_iteration = int(bst.attr('best_iteration'))
- if (maximize_score and score > best_score) or \
- (not maximize_score and score < best_score):
- # save the property to attributes, so they will occur in checkpoint.
- bst.set_attr(best_score=str(score),
- best_iteration=str(nboost - 1),
- best_msg=msg)
- elif i - best_iteration >= early_stopping_rounds:
- best_msg = bst.attr('best_msg')
- if verbose_eval:
- msg = "Stopping. Best iteration:\n{}\n\n"
- rabit.tracker_print(msg.format(best_msg))
- break
- # do checkpoint after evaluation, in case evaluation also updates booster.
- bst.save_rabit_checkpoint()
- version += 1
-
- if early_stopping_rounds:
- bst.best_score = float(bst.attr('best_score'))
- bst.best_iteration = int(bst.attr('best_iteration'))
- else:
- bst.best_iteration = nboost - 1
- bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
- return bst
+ return _train_internal(params, dtrain,
+ num_boost_round=num_boost_round,
+ evals=evals,
+ obj=obj, feval=feval,
+ xgb_model=xgb_model, callbacks=callbacks)
class CVPack(object):
@@ -294,7 +257,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
return ret
-def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
+def aggcv(rlist):
# pylint: disable=invalid-name
"""
Aggregate cross-validation results.
@@ -315,50 +278,21 @@ def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
if k not in cvmap:
cvmap[k] = []
cvmap[k].append(float(v))
-
msg = idx
-
- if show_stdv:
- fmt = '\tcv-{0}:{1}+{2}'
- else:
- fmt = '\tcv-{0}:{1}'
-
- index = []
results = []
- for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
+ for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
v = np.array(v)
if not isinstance(msg, STRING_TYPES):
msg = msg.decode()
mean, std = np.mean(v), np.std(v)
- msg += fmt.format(k, mean, std)
-
- index.extend([k + '-mean', k + '-std'])
- results.extend([mean, std])
-
- if as_pandas:
- try:
- import pandas as pd
- results = pd.Series(results, index=index)
- except ImportError:
- if verbose_eval is None:
- verbose_eval = True
- else:
- # if verbose_eval is default (None),
- # result will be np.ndarray as it can't hold column name
- if verbose_eval is None:
- verbose_eval = True
-
- if (isinstance(verbose_eval, int) and verbose_eval > 0 and trial % verbose_eval == 0) or \
- (isinstance(verbose_eval, bool) and verbose_eval):
- sys.stderr.write(msg + '\n')
- sys.stderr.flush()
-
+ results.extend([(k, mean, std)])
return results
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
- fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0):
+ fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0,
+ callbacks=None):
# pylint: disable = invalid-name
"""Cross-validation with given paramaters.
@@ -404,6 +338,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
Results are not affected, and always contains std.
seed : int
Seed used to generate the folds (passed to numpy.random.seed).
+ callbacks : list of callback functions
+ List of callback functions that are applied at the end of each iteration.
Returns
-------
@@ -431,59 +367,63 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
params.pop("eval_metric", None)
- if early_stopping_rounds is not None:
-
- if len(metrics) > 1:
- msg = ('Check your params. '
- 'Early stopping works with single eval metric only.')
- raise ValueError(msg)
- if verbose_eval:
- msg = "Will train until cv error hasn't decreased in {} rounds.\n"
- sys.stderr.write(msg.format(early_stopping_rounds))
-
- maximize_score = False
- if len(metrics) == 1:
- maximize_metrics = ('auc', 'map', 'ndcg')
- if any(metrics[0].startswith(x) for x in maximize_metrics):
- maximize_score = True
- if feval is not None:
- maximize_score = maximize
-
- if maximize_score:
- best_score = 0.0
- else:
- best_score = float('inf')
-
- best_score_i = 0
- results = []
+ results = {}
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
+
+ # setup callbacks
+ callbacks = [] if callbacks is None else callbacks
+ if early_stopping_rounds is not None:
+ callbacks.append(callback.early_stop(early_stopping_rounds,
+ maximize=maximize,
+ verbose=False))
+ if isinstance(verbose_eval, bool) and verbose_eval:
+ callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
+ else:
+ if isinstance(verbose_eval, int):
+ callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))
+
+ callbacks_before_iter = [
+ cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
+ callbacks_after_iter = [
+ cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
+
for i in range(num_boost_round):
+ for cb in callbacks_before_iter:
+ cb(CallbackEnv(model=None,
+ cvfolds=cvfolds,
+ iteration=i,
+ begin_iteration=0,
+ end_iteration=num_boost_round,
+ rank=0,
+ evaluation_result_list=None))
for fold in cvfolds:
fold.update(i, obj)
- res = aggcv([f.eval(i, feval) for f in cvfolds],
- show_stdv=show_stdv, verbose_eval=verbose_eval,
- as_pandas=as_pandas, trial=i)
- results.append(res)
+ res = aggcv([f.eval(i, feval) for f in cvfolds])
- if early_stopping_rounds is not None:
- score = res[0]
- if (maximize_score and score > best_score) or \
- (not maximize_score and score < best_score):
- best_score = score
- best_score_i = i
- elif i - best_score_i >= early_stopping_rounds:
- results = results[:best_score_i + 1]
- if verbose_eval:
- msg = "Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n"
- sys.stderr.write(msg.format(best_score_i, results[-1][0], results[-1][1]))
- break
+ for key, mean, std in res:
+ if key + '-mean' not in results:
+ results[key + '-mean'] = []
+ if key + '-std' not in results:
+ results[key + '-std'] = []
+ results[key + '-mean'].append(mean)
+ results[key + '-std'].append(std)
+ try:
+ for cb in callbacks_after_iter:
+ cb(CallbackEnv(model=None,
+ cvfolds=cvfolds,
+ iteration=i,
+ begin_iteration=0,
+ end_iteration=num_boost_round,
+ rank=0,
+ evaluation_result_list=res))
+ except EarlyStopException as e:
+ for k in results.keys():
+ results[k] = results[k][:(e.best_iteration + 1)]
+ break
if as_pandas:
try:
import pandas as pd
- results = pd.DataFrame(results)
+ results = pd.DataFrame.from_dict(results)
except ImportError:
- results = np.array(results)
- else:
- results = np.array(results)
-
+ pass
return results
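Net effect of the training.py rewrite: `verbose_eval`, `early_stopping_rounds`, `learning_rates` and `evals_result` are now thin wrappers over the callback API, and `cv` accumulates a plain dict keyed `<eval>-<metric>-mean` / `<eval>-<metric>-std`. A hedged usage sketch (demo data paths as used by the tests; round counts illustrative):

```python
import xgboost as xgb

dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('demo/data/agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'silent': 1,
         'objective': 'binary:logistic'}

# The legacy keywords still work; train() just translates them...
bst = xgb.train(param, dtrain, num_boost_round=10,
                evals=[(dtest, 'eval')],
                early_stopping_rounds=3, verbose_eval=True)

# ...into the equivalent explicit callbacks.
history = {}
bst = xgb.train(param, dtrain, num_boost_round=10,
                evals=[(dtest, 'eval')],
                callbacks=[xgb.callback.print_evaluation(),
                           xgb.callback.early_stop(3),
                           xgb.callback.record_evaluation(history)])

# cv() without pandas now returns a dict instead of an ndarray.
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=3, as_pandas=False)
print(sorted(res.keys()))  # e.g. ['test-error-mean', 'test-error-std', ...]
```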
diff --git a/rabit b/rabit
index e19fced5c..8f61535b8 160000
--- a/rabit
+++ b/rabit
@@ -1 +1 @@
-Subproject commit e19fced5cbd4e41b10099facae7caa5cd3e6ada3
+Subproject commit 8f61535b83e650331459d7f33a1615fa7d27b7bd
diff --git a/src/cli_main.cc b/src/cli_main.cc
index 039ce4070..e79592615 100644
--- a/src/cli_main.cc
+++ b/src/cli_main.cc
@@ -271,6 +271,7 @@ void CLIDump2Text(const CLIParam& param) {
std::unique_ptr<Learner> learner(Learner::Create({}));
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
+ learner->Configure(param.cfg);
learner->Load(fi.get());
// dump data
std::vector<std::string> dump = learner->Dump2Text(fmap, param.dump_stats);
@@ -297,6 +298,7 @@ void CLIPredict(const CLIParam& param) {
std::unique_ptr<Learner> learner(Learner::Create({}));
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
+ learner->Configure(param.cfg);
learner->Load(fi.get());
if (param.silent == 0) {
diff --git a/src/data/data.cc b/src/data/data.cc
index e8135692a..b1299a53c 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -223,6 +223,10 @@ DMatrix* DMatrix::Load(const std::string& uri,
LOG(CONSOLE) << info.base_margin.size()
<< " base_margin are loaded from " << fname << ".base_margin";
}
+ if (MetaTryLoadFloatInfo(fname + ".weight", &info.weights) && !silent) {
+ LOG(CONSOLE) << info.weights.size()
+ << " weights are loaded from " << fname << ".weight";
+ }
}
return dmat;
}
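The new branch mirrors the existing `.base_margin` sidecar: when a text file named `<data>.weight` with one float per row sits beside the data file, its contents are loaded as instance weights. A hedged sketch of how this surfaces in the Python API; the file name and row count are illustrative:

```python
import xgboost as xgb

# Assume train.libsvm has 100 rows; a sidecar train.libsvm.weight
# with one weight per line is picked up automatically on load.
with open('train.libsvm.weight', 'w') as f:
    for _ in range(100):
        f.write('1.0\n')

dtrain = xgb.DMatrix('train.libsvm')
print(dtrain.get_weight()[:5])  # weights loaded from the sidecar
```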
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index d25e06492..74a85e9ca 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -256,41 +256,44 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
name_shards.push_back(prefix + ".col.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).second);
}
- SparsePage::Writer writer(name_shards, format_shards, 6);
- std::unique_ptr<SparsePage> page;
- writer.Alloc(&page); page->Clear();
- double tstart = dmlc::GetTime();
- size_t bytes_write = 0;
- // print every 4 sec.
- const double kStep = 4.0;
- size_t tick_expected = kStep;
+ {
+ SparsePage::Writer writer(name_shards, format_shards, 6);
+ std::unique_ptr<SparsePage> page;
+ writer.Alloc(&page); page->Clear();
- while (make_next_col(page.get())) {
- for (size_t i = 0; i < page->Size(); ++i) {
- col_size_[i] += page->offset[i + 1] - page->offset[i];
- }
-
- bytes_write += page->MemCostBytes();
- writer.PushWrite(std::move(page));
- writer.Alloc(&page);
- page->Clear();
-
- double tdiff = dmlc::GetTime() - tstart;
- if (tdiff >= tick_expected) {
- LOG(CONSOLE) << "Writing col.page file to " << cache_info_
- << " in " << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
- << (bytes_write >> 20UL) << " MB writen";
- tick_expected += kStep;
+ double tstart = dmlc::GetTime();
+ size_t bytes_write = 0;
+ // print every 4 sec.
+ const double kStep = 4.0;
+ size_t tick_expected = kStep;
+
+ while (make_next_col(page.get())) {
+ for (size_t i = 0; i < page->Size(); ++i) {
+ col_size_[i] += page->offset[i + 1] - page->offset[i];
+ }
+
+ bytes_write += page->MemCostBytes();
+ writer.PushWrite(std::move(page));
+ writer.Alloc(&page);
+ page->Clear();
+
+ double tdiff = dmlc::GetTime() - tstart;
+ if (tdiff >= tick_expected) {
+ LOG(CONSOLE) << "Writing col.page file to " << cache_info_
+ << " in " << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
+ << (bytes_write >> 20UL) << " MB written";
+ tick_expected += kStep;
+ }
}
+ // save meta data
+ std::string col_meta_name = cache_shards[0] + ".col.meta";
+ std::unique_ptr<dmlc::Stream> fo(
+ dmlc::Stream::Create(col_meta_name.c_str(), "w"));
+ fo->Write(buffered_rowset_);
+ fo->Write(col_size_);
+ fo.reset(nullptr);
}
- // save meta data
- std::string col_meta_name = cache_shards[0] + ".col.meta";
- std::unique_ptr<dmlc::Stream> fo(
- dmlc::Stream::Create(col_meta_name.c_str(), "w"));
- fo->Write(buffered_rowset_);
- fo->Write(col_size_);
- fo.reset(nullptr);
// initialize column data
CHECK(TryInitColData());
}
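The restructuring above is purely about object lifetime: `SparsePage::Writer`'s destructor joins its background writer threads, so the page and `.col.meta` files are only guaranteed complete once the writer leaves scope, and `TryInitColData()` must not run before then. The added braces force that ordering. A rough Python analogy of the same write-then-read discipline (not xgboost API; the `with` block plays the role of the new scope):

```python
# The writer is flushed and closed when the 'with' block ends --
# the analogue of the C++ writer's destructor running at the brace.
def write_then_read(path, rows):
    with open(path, 'w') as writer:
        for row in rows:
            writer.write('%s\n' % row)
    # Only now is it safe to read the data back, mirroring the
    # CHECK(TryInitColData()) placed after the new scope.
    with open(path) as reader:
        return reader.read().splitlines()

assert write_then_read('pages.txt', ['a', 'b']) == ['a', 'b']
```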
diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc
index 3f499739f..ab5d7e650 100644
--- a/src/data/sparse_page_source.cc
+++ b/src/data/sparse_page_source.cc
@@ -110,58 +110,60 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
name_shards.push_back(prefix + ".row.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
}
- SparsePage::Writer writer(name_shards, format_shards, 6);
- std::unique_ptr<SparsePage> page;
- writer.Alloc(&page); page->Clear();
+ {
+ SparsePage::Writer writer(name_shards, format_shards, 6);
+ std::unique_ptr<SparsePage> page;
+ writer.Alloc(&page); page->Clear();
- MetaInfo info;
- size_t bytes_write = 0;
- double tstart = dmlc::GetTime();
- // print every 4 sec.
- const double kStep = 4.0;
- size_t tick_expected = kStep;
+ MetaInfo info;
+ size_t bytes_write = 0;
+ double tstart = dmlc::GetTime();
+ // print every 4 sec.
+ const double kStep = 4.0;
+ size_t tick_expected = kStep;
- while (src->Next()) {
- const dmlc::RowBlock& batch = src->Value();
- if (batch.label != nullptr) {
- info.labels.insert(info.labels.end(), batch.label, batch.label + batch.size);
- }
- if (batch.weight != nullptr) {
- info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size);
- }
- info.num_row += batch.size;
- info.num_nonzero += batch.offset[batch.size] - batch.offset[0];
- for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
- uint32_t index = batch.index[i];
- info.num_col = std::max(info.num_col,
- static_cast<uint64_t>(index + 1));
- }
- page->Push(batch);
- if (page->MemCostBytes() >= kPageSize) {
- bytes_write += page->MemCostBytes();
- writer.PushWrite(std::move(page));
- writer.Alloc(&page);
- page->Clear();
+ while (src->Next()) {
+ const dmlc::RowBlock& batch = src->Value();
+ if (batch.label != nullptr) {
+ info.labels.insert(info.labels.end(), batch.label, batch.label + batch.size);
+ }
+ if (batch.weight != nullptr) {
+ info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size);
+ }
+ info.num_row += batch.size;
+ info.num_nonzero += batch.offset[batch.size] - batch.offset[0];
+ for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
+ uint32_t index = batch.index[i];
+ info.num_col = std::max(info.num_col,
+ static_cast<uint64_t>(index + 1));
+ }
+ page->Push(batch);
+ if (page->MemCostBytes() >= kPageSize) {
+ bytes_write += page->MemCostBytes();
+ writer.PushWrite(std::move(page));
+ writer.Alloc(&page);
+ page->Clear();
- double tdiff = dmlc::GetTime() - tstart;
- if (tdiff >= tick_expected) {
- LOG(CONSOLE) << "Writing row.page to " << cache_info << " in "
- << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
- << (bytes_write >> 20UL) << " written";
- tick_expected += kStep;
+ double tdiff = dmlc::GetTime() - tstart;
+ if (tdiff >= tick_expected) {
+ LOG(CONSOLE) << "Writing row.page to " << cache_info << " in "
+ << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
+ << (bytes_write >> 20UL) << " written";
+ tick_expected += kStep;
+ }
}
}
- }
- if (page->data.size() != 0) {
- writer.PushWrite(std::move(page));
- }
+ if (page->data.size() != 0) {
+ writer.PushWrite(std::move(page));
+ }
- std::unique_ptr<dmlc::Stream> fo(
- dmlc::Stream::Create(name_info.c_str(), "w"));
- int tmagic = kMagic;
- fo->Write(&tmagic, sizeof(tmagic));
- info.SaveBinary(fo.get());
+ std::unique_ptr<dmlc::Stream> fo(
+ dmlc::Stream::Create(name_info.c_str(), "w"));
+ int tmagic = kMagic;
+ fo->Write(&tmagic, sizeof(tmagic));
+ info.SaveBinary(fo.get());
+ }
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
}
@@ -176,38 +178,39 @@ void SparsePageSource::Create(DMatrix* src,
name_shards.push_back(prefix + ".row.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
}
- SparsePage::Writer writer(name_shards, format_shards, 6);
- std::unique_ptr<SparsePage> page;
- writer.Alloc(&page); page->Clear();
+ {
+ SparsePage::Writer writer(name_shards, format_shards, 6);
+ std::unique_ptr<SparsePage> page;
+ writer.Alloc(&page); page->Clear();
- MetaInfo info;
- size_t bytes_write = 0;
- double tstart = dmlc::GetTime();
- dmlc::DataIter<RowBatch>* iter = src->RowIterator();
+ MetaInfo info;
+ size_t bytes_write = 0;
+ double tstart = dmlc::GetTime();
+ dmlc::DataIter<RowBatch>* iter = src->RowIterator();
- while (iter->Next()) {
- page->Push(iter->Value());
- if (page->MemCostBytes() >= kPageSize) {
- bytes_write += page->MemCostBytes();
- writer.PushWrite(std::move(page));
- writer.Alloc(&page);
- page->Clear();
- double tdiff = dmlc::GetTime() - tstart;
- LOG(CONSOLE) << "Writing to " << cache_info << " in "
- << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
- << (bytes_write >> 20UL) << " written";
+ while (iter->Next()) {
+ page->Push(iter->Value());
+ if (page->MemCostBytes() >= kPageSize) {
+ bytes_write += page->MemCostBytes();
+ writer.PushWrite(std::move(page));
+ writer.Alloc(&page);
+ page->Clear();
+ double tdiff = dmlc::GetTime() - tstart;
+ LOG(CONSOLE) << "Writing to " << cache_info << " in "
+ << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
+ << (bytes_write >> 20UL) << " written";
+ }
+ }
+ if (page->data.size() != 0) {
+ writer.PushWrite(std::move(page));
}
- }
- if (page->data.size() != 0) {
- writer.PushWrite(std::move(page));
+ std::unique_ptr<dmlc::Stream> fo(
+ dmlc::Stream::Create(name_info.c_str(), "w"));
+ int tmagic = kMagic;
+ fo->Write(&tmagic, sizeof(tmagic));
+ info.SaveBinary(fo.get());
}
-
- std::unique_ptr<dmlc::Stream> fo(
- dmlc::Stream::Create(name_info.c_str(), "w"));
- int tmagic = kMagic;
- fo->Write(&tmagic, sizeof(tmagic));
- info.SaveBinary(fo.get());
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
}
diff --git a/src/data/sparse_page_writer.cc b/src/data/sparse_page_writer.cc
index 33f9172d6..e16d1aee6 100644
--- a/src/data/sparse_page_writer.cc
+++ b/src/data/sparse_page_writer.cc
@@ -34,6 +34,7 @@ SparsePage::Writer::Writer(
fo->Write(format_shard);
std::unique_ptr<SparsePage> page;
while (wqueue->Pop(&page)) {
+ if (page.get() == nullptr) break;
fmt->Write(*page, fo.get());
qrecycle_.Push(std::move(page));
}
@@ -45,7 +46,9 @@ SparsePage::Writer::Writer(
SparsePage::Writer::~Writer() {
for (auto& queue : qworkers_) {
- queue.SignalForKill();
+ // use nullptr to signal termination.
+ std::unique_ptr<SparsePage> sig(nullptr);
+ queue.Push(std::move(sig));
}
for (auto& thread : workers_) {
thread->join();
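Replacing `SignalForKill` with a pushed `nullptr` turns shutdown into an ordinary queue message: each writer thread keeps popping until it sees the sentinel, so every page queued before destruction is written out before the thread exits. A rough Python analogy of the sentinel pattern (not xgboost API):

```python
import queue
import threading

q = queue.Queue()
written = []

def worker():
    while True:
        page = q.get()
        if page is None:          # sentinel: stop after draining the queue
            break
        written.append(page)      # stand-in for fmt->Write(*page, fo.get())

t = threading.Thread(target=worker)
t.start()
for p in ['page0', 'page1']:
    q.put(p)
q.put(None)  # what the destructor's queue.Push(std::move(sig)) does
t.join()     # what the loop over workers_ does
assert written == ['page0', 'page1']
```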
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 7e58a060a..a48fb2f94 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -17,6 +17,8 @@
#include
#include "../common/common.h"
+#include "../common/random.h"
+
namespace xgboost {
namespace gbm {
@@ -47,6 +49,42 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
}
};
+/*! \brief training parameters */
+struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
+ /*! \brief whether to suppress printing info during training */
+ bool silent;
+ /*! \brief type of sampling algorithm */
+ int sample_type;
+ /*! \brief type of normalization algorithm */
+ int normalize_type;
+ /*! \brief fraction of trees dropped in each iteration */
+ float rate_drop;
+ /*! \brief probability of skipping the dropout in an iteration */
+ float skip_drop;
+ /*! \brief learning step size for each iteration */
+ float learning_rate;
+ // declare parameters
+ DMLC_DECLARE_PARAMETER(DartTrainParam) {
+ DMLC_DECLARE_FIELD(silent).set_default(false)
+ .describe("Not print information during trainig.");
+ DMLC_DECLARE_FIELD(sample_type).set_default(0)
+ .add_enum("uniform", 0)
+ .add_enum("weighted", 1)
+ .describe("Different types of sampling algorithm.");
+ DMLC_DECLARE_FIELD(normalize_type).set_default(0)
+ .add_enum("tree", 0)
+ .add_enum("forest", 1)
+ .describe("Different types of normalization algorithm.");
+ DMLC_DECLARE_FIELD(rate_drop).set_range(0.0f, 1.0f).set_default(0.0f)
+ .describe("Fraction of trees to drop during the dropout.");
+ DMLC_DECLARE_FIELD(skip_drop).set_range(0.0f, 1.0f).set_default(0.0f)
+ .describe("Probability of skipping the dropout during an iteration.");
+ DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
+ .describe("Learning rate(step size) of update.");
+ DMLC_DECLARE_ALIAS(learning_rate, eta);
+ }
+};
+
/*! \brief model parameters */
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
/*! \brief number of trees */
@@ -313,8 +351,9 @@ class GBTree : public GradientBooster {
}
}
// commit new trees all at once
- inline void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
- int bst_group) {
+ virtual void
+ CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
+ int bst_group) {
for (size_t i = 0; i < new_trees.size(); ++i) {
trees.push_back(std::move(new_trees[i]));
tree_info.push_back(bst_group);
@@ -475,14 +514,236 @@ class GBTree : public GradientBooster {
std::vector<std::unique_ptr<TreeUpdater> > updaters;
};
+// dart
+class Dart : public GBTree {
+ public:
+ Dart() {}
+
+ void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
+ GBTree::Configure(cfg);
+ if (trees.size() == 0) {
+ dparam.InitAllowUnknown(cfg);
+ }
+ }
+
+ void Load(dmlc::Stream* fi) override {
+ GBTree::Load(fi);
+ weight_drop.resize(mparam.num_trees);
+ if (mparam.num_trees != 0) {
+ fi->Read(&weight_drop);
+ }
+ }
+
+ void Save(dmlc::Stream* fo) const override {
+ GBTree::Save(fo);
+ if (weight_drop.size() != 0) {
+ fo->Write(weight_drop);
+ }
+ }
+
+ // predict the leaf scores with dropout if ntree_limit = 0
+ void Predict(DMatrix* p_fmat,
+ int64_t buffer_offset,
+ std::vector<float>* out_preds,
+ unsigned ntree_limit) override {
+ DropTrees(ntree_limit);
+ const MetaInfo& info = p_fmat->info();
+ int nthread;
+ #pragma omp parallel
+ {
+ nthread = omp_get_num_threads();
+ }
+ InitThreadTemp(nthread);
+ std::vector<float>& preds = *out_preds;
+ const size_t stride = p_fmat->info().num_row * mparam.num_output_group;
+ preds.resize(stride * (mparam.size_leaf_vector+1));
+ // start collecting the prediction
+ dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+
+ iter->BeforeFirst();
+ while (iter->Next()) {
+ const RowBatch &batch = iter->Value();
+ // parallel over local batch
+ const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+ #pragma omp parallel for schedule(static)
+ for (bst_omp_uint i = 0; i < nsize; ++i) {
+ const int tid = omp_get_thread_num();
+ RegTree::FVec &feats = thread_temp[tid];
+ int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
+ CHECK_LT(static_cast<size_t>(ridx), info.num_row);
+ // loop over output groups
+ for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+ this->Pred(batch[i],
+ buffer_offset < 0 ? -1 : buffer_offset + ridx,
+ gid, info.GetRoot(ridx), &feats,
+ &preds[ridx * mparam.num_output_group + gid], stride,
+ ntree_limit);
+ }
+ }
+ }
+ }
+
+ void Predict(const SparseBatch::Inst& inst,
+ std::vector<float>* out_preds,
+ unsigned ntree_limit,
+ unsigned root_index) override {
+ DropTrees(1);
+ if (thread_temp.size() == 0) {
+ thread_temp.resize(1, RegTree::FVec());
+ thread_temp[0].Init(mparam.num_feature);
+ }
+ out_preds->resize(mparam.num_output_group * (mparam.size_leaf_vector+1));
+ // loop over output groups
+ for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+ this->Pred(inst, -1, gid, root_index, &thread_temp[0],
+ &(*out_preds)[gid], mparam.num_output_group,
+ ntree_limit);
+ }
+ }
+
+ protected:
+ // commit new trees all at once
+ virtual void
+ CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
+ int bst_group) {
+ for (size_t i = 0; i < new_trees.size(); ++i) {
+ trees.push_back(std::move(new_trees[i]));
+ tree_info.push_back(bst_group);
+ }
+ mparam.num_trees += static_cast<int>(new_trees.size());
+ size_t num_drop = NormalizeTrees(new_trees.size());
+ if (dparam.silent != 1) {
+ LOG(INFO) << "drop " << num_drop << " trees, "
+ << "weight = " << weight_drop.back();
+ }
+ }
+ // predict the leaf scores without dropped trees
+ inline void Pred(const RowBatch::Inst &inst,
+ int64_t buffer_index,
+ int bst_group,
+ unsigned root_index,
+ RegTree::FVec *p_feats,
+ float *out_pred,
+ size_t stride,
+ unsigned ntree_limit) {
+ float psum = 0.0f;
+ // sum of leaf vector
+ std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
+ const int64_t bid = this->BufferOffset(buffer_index, bst_group);
+ p_feats->Fill(inst);
+ for (size_t i = 0; i < trees.size(); ++i) {
+ if (tree_info[i] == bst_group) {
+ bool drop = (std::find(idx_drop.begin(), idx_drop.end(), i) != idx_drop.end());
+ if (!drop) {
+ int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
+ psum += weight_drop[i] * (*trees[i])[tid].leaf_value();
+ for (int j = 0; j < mparam.size_leaf_vector; ++j) {
+ vec_psum[j] += weight_drop[i] * trees[i]->leafvec(tid)[j];
+ }
+ }
+ }
+ }
+ p_feats->Drop(inst);
+ // update the buffered results
+ if (bid >= 0 && ntree_limit == 0) {
+ pred_counter[bid] = static_cast<unsigned>(trees.size());
+ pred_buffer[bid] = psum;
+ for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+ pred_buffer[bid + i + 1] = vec_psum[i];
+ }
+ }
+ out_pred[0] = psum;
+ for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+ out_pred[stride * (i + 1)] = vec_psum[i];
+ }
+ }
+
+ // select dropped trees
+ inline void DropTrees(unsigned ntree_limit_drop) {
+ std::uniform_real_distribution<> runif(0.0, 1.0);
+ auto& rnd = common::GlobalRandom();
+ // reset
+ idx_drop.clear();
+ // sample dropped trees
+ bool skip = false;
+ if (dparam.skip_drop > 0.0) skip = (runif(rnd) < dparam.skip_drop);
+ if (ntree_limit_drop == 0 && !skip) {
+ if (dparam.sample_type == 1) {
+ float sum_weight = 0.0;
+ for (size_t i = 0; i < weight_drop.size(); ++i) {
+ sum_weight += weight_drop[i];
+ }
+ for (size_t i = 0; i < weight_drop.size(); ++i) {
+ if (runif(rnd) < dparam.rate_drop * weight_drop.size() * weight_drop[i] / sum_weight) {
+ idx_drop.push_back(i);
+ }
+ }
+ } else {
+ for (size_t i = 0; i < weight_drop.size(); ++i) {
+ if (runif(rnd) < dparam.rate_drop) {
+ idx_drop.push_back(i);
+ }
+ }
+ }
+ }
+ }
+ // set normalization factors
+ inline size_t NormalizeTrees(size_t size_new_trees) {
+ float lr = 1.0 * dparam.learning_rate / size_new_trees;
+ size_t num_drop = idx_drop.size();
+ if (num_drop == 0) {
+ for (size_t i = 0; i < size_new_trees; ++i) {
+ weight_drop.push_back(1.0);
+ }
+ } else {
+ if (dparam.normalize_type == 1) {
+ // normalize_type 1
+ float factor = 1.0 / (1.0 + lr);
+ for (size_t i = 0; i < idx_drop.size(); ++i) {
+ weight_drop[idx_drop[i]] *= factor;
+ }
+ for (size_t i = 0; i < size_new_trees; ++i) {
+ weight_drop.push_back(lr * factor);
+ }
+ } else {
+ // normalize_type 0
+ float factor = 1.0 * num_drop / (num_drop + lr);
+ for (size_t i = 0; i < idx_drop.size(); ++i) {
+ weight_drop[idx_drop[i]] *= factor;
+ }
+ for (size_t i = 0; i < size_new_trees; ++i) {
+ weight_drop.push_back(1.0 * lr / (num_drop + lr));
+ }
+ }
+ }
+ // reset
+ idx_drop.clear();
+ return num_drop;
+ }
+
+ // --- data structure ---
+ // training parameter
+ DartTrainParam dparam;
+ /*! \brief weight assigned to each tree, accounting for dropouts */
+ std::vector<float> weight_drop;
+ // indexes of dropped trees
+ std::vector<int> idx_drop;
+};
+
// register the objective functions
DMLC_REGISTER_PARAMETER(GBTreeModelParam);
DMLC_REGISTER_PARAMETER(GBTreeTrainParam);
+DMLC_REGISTER_PARAMETER(DartTrainParam);
XGBOOST_REGISTER_GBM(GBTree, "gbtree")
.describe("Tree booster, gradient boosted trees.")
.set_body([]() {
return new GBTree();
});
+XGBOOST_REGISTER_GBM(Dart, "dart")
+.describe("Tree booster, dart.")
+.set_body([]() {
+ return new Dart();
+ });
} // namespace gbm
} // namespace xgboost
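For reference, the dropout bookkeeping in `NormalizeTrees` keeps the ensemble's output on the same scale after a dropout round: with `lr = learning_rate / size_new_trees` and `k` dropped trees, tree-wise normalization scales each dropped tree by `k / (k + lr)` and weights the new tree `lr / (k + lr)`, while forest-wise normalization uses `1 / (1 + lr)` for the dropped trees and `lr / (1 + lr)` for the new one. A hedged Python sketch of that arithmetic (names and values illustrative; single new tree):

```python
def normalize_trees(weights, idx_drop, lr, normalize_type='tree'):
    """Mirror of the C++ weight update for one new tree."""
    k = len(idx_drop)
    if k == 0:
        return weights + [1.0]          # nothing dropped: plain boosting
    if normalize_type == 'forest':      # normalize_type 1
        factor = 1.0 / (1.0 + lr)
        new_weight = lr * factor
    else:                               # normalize_type 0 ('tree')
        factor = 1.0 * k / (k + lr)
        new_weight = 1.0 * lr / (k + lr)
    scaled = [w * factor if i in idx_drop else w
              for i, w in enumerate(weights)]
    return scaled + [new_weight]

# Drop trees 0 and 2 out of three, learning rate 0.3:
print(normalize_trees([1.0, 1.0, 1.0], idx_drop={0, 2}, lr=0.3))
```

From the user's side the booster is selected like any other, by passing `'booster': 'dart'` together with `rate_drop`, `skip_drop`, `sample_type` and `normalize_type` in the parameter dict, as the new test below exercises.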
diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py
index 386462091..710de987d 100644
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -35,6 +35,22 @@ class TestBasic(unittest.TestCase):
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0
+ def test_record_results(self):
+ dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+ param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
+ # specify validations set to watch performance
+ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+ num_round = 2
+ result = {}
+ res2 = {}
+ xgb.train(param, dtrain, num_round, watchlist,
+ callbacks=[xgb.callback.record_evaluation(result)])
+ xgb.train(param, dtrain, num_round, watchlist,
+ evals_result=res2)
+ assert result['train']['error'][0] < 0.1
+ assert res2 == result
+
def test_multiclass(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
@@ -189,5 +205,5 @@ class TestBasic(unittest.TestCase):
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
- assert isinstance(cv, np.ndarray)
- assert cv.shape == (10, 4)
+ assert isinstance(cv, dict)
+ assert len(cv) == 4
diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py
index 9e9c08423..5275e2f04 100644
--- a/tests/python/test_basic_models.py
+++ b/tests/python/test_basic_models.py
@@ -23,6 +23,51 @@ class TestModels(unittest.TestCase):
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
assert err < 0.1
+ def test_dart(self):
+ dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+ param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart', 'silent': False}
+ # specify validations set to watch performance
+ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+ num_round = 2
+ bst = xgb.train(param, dtrain, num_round, watchlist)
+ # this is prediction
+ preds = bst.predict(dtest, ntree_limit=num_round)
+ labels = dtest.get_label()
+ err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+ # error must be smaller than 10%
+ assert err < 0.1
+
+ # save dmatrix into binary buffer
+ dtest.save_binary('dtest.buffer')
+ # save model
+ bst.save_model('xgb.model.dart')
+ # load model and data in
+ bst2 = xgb.Booster(params=param, model_file='xgb.model.dart')
+ dtest2 = xgb.DMatrix('dtest.buffer')
+ preds2 = bst2.predict(dtest2, ntree_limit=num_round)
+ # assert they are the same
+ assert np.sum(np.abs(preds2 - preds)) == 0
+
+ # check whether sample_type and normalize_type work
+ num_round = 50
+ param['silent'] = True
+ param['learning_rate'] = 0.1
+ param['rate_drop'] = 0.1
+ preds_list = []
+ for p in [[p0, p1] for p0 in ['uniform', 'weighted'] for p1 in ['tree', 'forest']]:
+ param['sample_type'] = p[0]
+ param['normalize_type'] = p[1]
+ bst = xgb.train(param, dtrain, num_round, watchlist)
+ preds = bst.predict(dtest, ntree_limit=num_round)
+ err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+ assert err < 0.1
+ preds_list.append(preds)
+
+ for ii in range(len(preds_list)):
+ for jj in range(ii + 1, len(preds_list)):
+ assert np.sum(np.abs(preds_list[ii] - preds_list[jj])) > 0
+
def test_eta_decay(self):
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 4
diff --git a/tests/python/test_early_stopping.py b/tests/python/test_early_stopping.py
index b015547a1..67e725b74 100644
--- a/tests/python/test_early_stopping.py
+++ b/tests/python/test_early_stopping.py
@@ -1,5 +1,5 @@
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import numpy as np
import unittest
diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py
index 2391bfe28..529ef698c 100644
--- a/tests/python/test_eval_metrics.py
+++ b/tests/python/test_eval_metrics.py
@@ -1,5 +1,5 @@
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import numpy as np
import unittest
diff --git a/tests/python/test_plotting.py b/tests/python/test_plotting.py
index 7a70bd95e..fde98dcca 100644
--- a/tests/python/test_plotting.py
+++ b/tests/python/test_plotting.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import unittest
try:
diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py
index 2cb93f9ac..f7511f685 100644
--- a/tests/python/test_training_continuation.py
+++ b/tests/python/test_training_continuation.py
@@ -1,5 +1,5 @@
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import numpy as np
import unittest
diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py
index 9536c1e82..0bef20ec2 100644
--- a/tests/python/test_with_pandas.py
+++ b/tests/python/test_with_pandas.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import unittest
try:
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 72ae27948..d079d99fe 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -1,7 +1,7 @@
import numpy as np
import random
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
rng = np.random.RandomState(1994)
diff --git a/python-package/xgboost/testing.py b/tests/python/testing.py
similarity index 87%
rename from python-package/xgboost/testing.py
rename to tests/python/testing.py
index 647a89fef..fb368dedd 100644
--- a/python-package/xgboost/testing.py
+++ b/tests/python/testing.py
@@ -17,6 +17,6 @@ def _skip_if_no_pandas():
def _skip_if_no_matplotlib():
try:
- import matplotlib.pyplot as plt # noqa
+ import matplotlib.pyplot as _ # noqa
except ImportError:
raise nose.SkipTest()