[PYTHON-DIST] Distributed xgboost python training API.

This commit is contained in:
tqchen 2016-02-29 10:00:37 -08:00
parent 51bb556898
commit ecb3a271be
16 changed files with 427 additions and 32 deletions

@ -1 +1 @@
Subproject commit 38ee75d95ff23e4e1febacc89e08975d9b6c6c3a
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0

View File

@ -358,6 +358,30 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
bst_ulong *out_len,
const char ***out_models);
/*!
* \brief Get string attribute from Booster.
* \param handle handle
* \param key The key of the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
const char* key,
const char** out,
int *success);
/*!
* \brief Set string attribute.
*
* \param handle handle
* \param key The key of the symbol.
* \param value The value to be saved.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
const char* key,
const char* value);
// --- Distributed training API----
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
/*!
@ -376,6 +400,6 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckPoint(BoosterHandle handle);
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
#endif // XGBOOST_C_API_H_

View File

@ -109,6 +109,21 @@ class Learner : public rabit::Serializable {
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const = 0;
/*!
* \brief Set additional attribute to the Booster.
* The property will be saved along the booster.
* \param key The key of the property.
* \param value The value of the property.
*/
virtual void SetAttr(const std::string& key, const std::string& value) = 0;
/*!
* \brief Get attribute from the booster.
* The property will be saved along the booster.
* \param key The key of the attribute.
* \param out The output value.
* \return Whether the key is contained in the attribute.
*/
virtual bool GetAttr(const std::string& key, std::string* out) const = 0;
/*!
* \return whether the model allow lazy checkpoint in rabit.
*/

View File

@ -10,6 +10,7 @@ import os
from .core import DMatrix, Booster
from .training import train, cv
from . import rabit
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
from .plotting import plot_importance, plot_tree, to_graphviz

View File

@ -12,9 +12,11 @@ PY3 = (sys.version_info[0] == 3)
if PY3:
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str,
py_str = lambda x: x.decode('utf-8')
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,
py_str = lambda x: x
# pandas
try:

View File

@ -12,7 +12,7 @@ import scipy.sparse
from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame
from .compat import STRING_TYPES, PY3, DataFrame, py_str
class XGBoostError(Exception):
"""Error throwed by xgboost trainer."""
@ -658,6 +658,59 @@ class Booster(object):
"""
return self.__copy__()
def load_rabit_checkpoint(self):
"""Initialize the model by load from rabit checkpoint.
Returns
-------
version: integer
The version number of the model.
"""
version = ctypes.c_int()
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
self.handle, ctypes.byref(version)))
return version.value
def save_rabit_checkpoint(self):
"""Save the current booster to rabit checkpoint."""
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
def attr(self, key):
"""Get attribute string from the Booster.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
_check_call(_LIB.XGBoosterGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
def set_attr(self, **kwargs):
"""Set the attribute of the Booster.
Parameters
----------
**kwargs
The attributes to set
"""
for key, value in kwargs.items():
if not isinstance(value, STRING_TYPES):
raise ValueError("Set Attr only accepts string values")
_check_call(_LIB.XGBoosterSetAttr(
self.handle, c_str(key), c_str(str(value))))
def set_param(self, params, value=None):
"""Set parameters into the Booster.

View File

@ -0,0 +1,188 @@
"""Distributed XGBoost Rabit related API."""
from __future__ import absolute_import
import sys
import atexit
import ctypes
import numpy as np
from .core import _LIB, c_str, STRING_TYPES
def _init_rabit():
"""Initialize the rabit library."""
_LIB.RabitGetRank.restype = ctypes.c_int
_LIB.RabitGetWorldSize.restype = ctypes.c_int
_LIB.RabitVersionNumber.restype = ctypes.c_int
_LIB.RabitInit(0, None)
def finalize():
"""Finalize the process, notify tracker everything is done."""
_LIB.RabitFinalize()
def get_rank():
"""Get rank of current process.
Returns
-------
rank : int
Rank of current process.
"""
ret = _LIB.RabitGetRank()
return ret
def get_world_size():
"""Get total number workers.
Returns
-------
n : int
Total number of process.
"""
ret = _LIB.RabitGetWorldSize()
return ret
def tracker_print(msg):
"""Print message to the tracker.
This function can be used to communicate the information of
the progress to the tracker
Parameters
----------
msg : str
The message to be printed to tracker.
"""
if not isinstance(msg, STRING_TYPES):
msg = str(msg)
_LIB.RabitTrackerPrint(c_str(msg))
def get_processor_name():
"""Get the processor name.
Returns
-------
name : str
the name of processor(host)
"""
mxlen = 256
length = ctypes.c_ulong()
buf = ctypes.create_string_buffer(mxlen)
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
return buf.value
def broadcast(data, root):
"""Broadcast object from one node to all other nodes.
Parameters
----------
data : any type that can be pickled
Input data, if current rank does not equal root, this can be None
root : int
Rank of the node to broadcast data from.
Returns
-------
object : int
the result of broadcast.
"""
rank = get_rank()
length = ctypes.c_ulong()
if root == rank:
assert data is not None, 'need to pass in data when broadcasting'
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
length.value = len(s)
# run first broadcast
_LIB.RabitBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong), root)
if root != rank:
dptr = (ctypes.c_char * length.value)()
# run second
_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
length.value, root)
data = pickle.loads(dptr.raw)
del dptr
else:
_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
length.value, root)
del s
return data
# enumeration of dtypes
DTYPE_ENUM__ = {
np.dtype('int8') : 0,
np.dtype('uint8') : 1,
np.dtype('int32') : 2,
np.dtype('uint32') : 3,
np.dtype('int64') : 4,
np.dtype('uint64') : 5,
np.dtype('float32') : 6,
np.dtype('float64') : 7
}
def allreduce(data, op, prepare_fun=None):
"""Perform allreduce, return the result.
Parameters
----------
data: numpy array
Input data.
op: int
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun: function
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to intialize the data
If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called
Returns
-------
result : array_like
The result of allreduce, have same shape as data
Notes
-----
This function is not thread-safe.
"""
if not isinstance(data, np.ndarray):
raise Exception('allreduce only takes in numpy.ndarray')
buf = data.ravel()
if buf.base is data.base:
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise Exception('data type %s not supported' % str(buf.dtype))
if prepare_fun is None:
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, None, None)
else:
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def pfunc(args):
"""prepare function."""
prepare_fun(data)
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
op, func_ptr(pfunc), None)
return buf
def version_number():
"""Returns version number of current stored model.
This means how many calls to CheckPoint we made so far.
Returns
-------
version : int
Version number of currently stored model
"""
ret = _LIB.RabitVersionNumber()
return ret
# intialization script
_init_rabit()

View File

@ -6,9 +6,11 @@ from __future__ import absolute_import
import sys
import re
import os
import numpy as np
from .core import Booster, STRING_TYPES
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
from . import rabit
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize=False, early_stopping_rounds=None, evals_result=None,
@ -94,6 +96,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
verbose_eval_every_line = verbose_eval
verbose_eval = True if verbose_eval_every_line > 0 else False
if rabit.get_rank() != 0:
verbose_eval = False;
if xgb_model is not None:
if not isinstance(xgb_model, STRING_TYPES):
xgb_model = xgb_model.save_raw()
@ -123,14 +128,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
raise ValueError('For early stopping you need at least one set in evals.')
if verbose_eval:
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
evals[-1][1], early_stopping_rounds))
# is params a list of tuples? are we using multiple eval metrics?
if isinstance(params, list):
if len(params) != len(dict(params).items()):
params = dict(params)
sys.stderr.write("Multiple eval metrics have been passed: " \
rabit.tracker_print("Multiple eval metrics have been passed: " \
"'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
else:
params = dict(params)
@ -145,23 +150,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize_score = maximize
if maximize_score:
best_score = 0.0
bst.set_attr(best_score='0.0')
else:
best_score = float('inf')
best_msg = ''
best_score_i = (nboost - 1)
bst.set_attr(best_score='inf')
bst.set_attr(best_iteration='0')
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
for i in range(nboost, nboost + num_boost_round):
# Distributed code: Load the checkpoint from rabit.
version = bst.load_rabit_checkpoint()
assert(rabit.get_world_size() != 1 or version == 0)
start_iteration = int(version / 2)
nboost += start_iteration
for i in range(start_iteration, num_boost_round):
if learning_rates is not None:
if isinstance(learning_rates, list):
bst.set_param({'eta': learning_rates[i]})
else:
bst.set_param({'eta': learning_rates(i, num_boost_round)})
# Distributed code: need to resume to this point.
# Skip the first update if it is a recovery step.
if version % 2 == 0:
bst.update(dtrain, i, obj)
bst.save_rabit_checkpoint()
version += 1
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
nboost += 1
# check evaluation result.
@ -176,9 +193,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
if verbose_eval:
if verbose_eval_every_line:
if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
sys.stderr.write(msg + '\n')
rabit.tracker_print(msg + '\n')
else:
sys.stderr.write(msg + '\n')
rabit.tracker_print(msg + '\n')
if evals_result is not None:
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
@ -196,22 +213,26 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
if early_stopping_rounds:
score = float(msg.rsplit(':', 1)[1])
best_score = float(bst.attr('best_score'))
best_iteration = int(bst.attr('best_iteration'))
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
best_score = score
best_score_i = (nboost - 1)
best_msg = msg
elif i - best_score_i >= early_stopping_rounds:
# save the property to attributes, so they will occur in checkpoint.
bst.set_attr(best_score=str(score),
best_iteration=str(nboost - 1),
best_msg=msg)
elif i - best_iteration >= early_stopping_rounds:
best_msg = bst.attr('best_msg')
if verbose_eval:
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
# best iteration will be assigned in the end.
bst.best_score = best_score
bst.best_iteration = best_score_i
rabit.tracker_print("Stopping. Best iteration:\n{}\n\n".format(best_msg))
break
# do checkpoint after evaluation, in case evaluation also updates booster.
bst.save_rabit_checkpoint()
version += 1
if early_stopping_rounds:
best_score = best_score
bst.best_iteration = best_score_i
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = nboost - 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree

2
rabit

@ -1 +1 @@
Subproject commit 56ec4263f9a70a315c1f153dc5897b7c1b58250c
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0

View File

@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
<< "will split data among workers";
}
*out = DMatrix::Load(
fname, silent != 0, false);
fname, false, true);
API_END();
}
@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
API_END();
}
int XGBoosterGetAttr(BoosterHandle handle,
const char* key,
const char** out,
int* success) {
Booster* bst = static_cast<Booster*>(handle);
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
if (bst->learner()->GetAttr(key, &ret_str)) {
*out = ret_str.c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
}
API_END();
}
int XGBoosterSetAttr(BoosterHandle handle,
const char* key,
const char* value) {
Booster* bst = static_cast<Booster*>(handle);
API_BEGIN();
bst->learner()->SetAttr(key, value);
API_END();
}
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner());
if (version != 0) {
if (*version != 0) {
bst->initialized_ = true;
}
API_END();
}
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) {

View File

@ -184,7 +184,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
<< " of " << npart << " parts";
}
// legacy handling of binary data loading
if (file_format == "auto" && !load_row_split) {
if (file_format == "auto" && npart == 1) {
int magic;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi.get() != nullptr) {

View File

@ -5,6 +5,7 @@
* \author Tianqi Chen
*/
#include <xgboost/learner.h>
#include <dmlc/io.h>
#include <algorithm>
#include <vector>
#include <utility>
@ -43,8 +44,10 @@ struct LearnerModelParam
unsigned num_feature;
/* \brief number of classes, if it is multi-class classification */
int num_class;
/*! \brief Model contain additional properties */
int contain_extra_attrs;
/*! \brief reserved field */
int reserved[31];
int reserved[30];
/*! \brief constructor */
LearnerModelParam() {
std::memset(this, 0, sizeof(LearnerModelParam));
@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
obj_.reset(ObjFunction::Create(name_obj_));
gbm_.reset(GradientBooster::Create(name_gbm_));
gbm_->Load(fi);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr;
fi->Read(&attr);
attributes_ = std::map<std::string, std::string>(
attr.begin(), attr.end());
}
if (metrics_.size() == 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
fo->Write(name_obj_);
fo->Write(name_gbm_);
gbm_->Save(fo);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr(
attributes_.begin(), attributes_.end());
fo->Write(attr);
}
}
void UpdateOneIter(int iter, DMatrix* train) override {
@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
return os.str();
}
void SetAttr(const std::string& key, const std::string& value) override {
attributes_[key] = value;
mparam.contain_extra_attrs = 1;
}
bool GetAttr(const std::string& key, std::string* out) const override {
auto it = attributes_.find(key);
if (it == attributes_.end()) return false;
*out = it->second;
return true;
}
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
LearnerTrainParam tparam;
// configurations
std::map<std::string, std::string> cfg_;
// attributes
std::map<std::string, std::string> attributes_;
// name of gbm
std::string name_gbm_;
// name of objective functon

View File

@ -56,6 +56,9 @@ class BaseMaker: public TreeUpdater {
}
}
}
}
/*! \brief synchronize the information */
inline void SyncInfo() {
rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
}
// get feature type, 0:empty 1:binary 2:real

View File

@ -313,6 +313,7 @@ class CQHistMaker: public HistMaker<TStats> {
feat_helper.InitByCol(p_fmat, tree);
cache_dmatrix_ = p_fmat;
}
feat_helper.SyncInfo();
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
}
// code to create histogram

4
tests/distributed/runtests.sh Executable file
View File

@ -0,0 +1,4 @@
#!/bin/bash
PYTHONPATH=../../python-package/ ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3\
python test_basic.py

View File

@ -0,0 +1,29 @@
#!/usr/bin/python
import numpy as np
import scipy.sparse
import pickle
import xgboost as xgb
# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
# specify parameters via map, definition are same as c++ version
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
# specify validations set to watch performance
watchlist = [(dtest,'eval'), (dtrain,'train')]
num_round = 20
# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
# save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()