[PYTHON-DIST] Distributed xgboost python training API.

parent 51bb556898
commit ecb3a271be
@@ -1 +1 @@
Subproject commit 38ee75d95ff23e4e1febacc89e08975d9b6c6c3a
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
@@ -358,6 +358,30 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                           bst_ulong *out_len,
                                           const char ***out_models);

/*!
 * \brief Get string attribute from Booster.
 * \param handle handle
 * \param key The key of the attribute.
 * \param out The result attribute, which can be NULL if the attribute does not exist.
 * \param success Whether the result is contained in out.
 * \return 0 on success, -1 on failure
 */
XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
                             const char* key,
                             const char** out,
                             int *success);
/*!
 * \brief Set string attribute.
 *
 * \param handle handle
 * \param key The key of the attribute.
 * \param value The value to be saved.
 * \return 0 on success, -1 on failure
 */
XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
                             const char* key,
                             const char* value);

// --- Distributed training API ----
// NOTE: functions in rabit/c_api.h will also be available in libxgboost.so
/*!
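For orientation, here is a minimal sketch of driving these two entry points through ctypes. The library path and the booster handle are assumptions (any valid BoosterHandle from XGBoosterCreate would do); this helper is not part of the commit:

import ctypes

lib = ctypes.cdll.LoadLibrary('libxgboost.so')  # assumed library location

def get_attr(booster_handle, key):
    # XGBoosterGetAttr writes the value into `out` and sets `success` to 1
    # only when the key exists; the return code signals API-level failure.
    out = ctypes.c_char_p()
    success = ctypes.c_int()
    if lib.XGBoosterGetAttr(booster_handle, key.encode('utf-8'),
                            ctypes.byref(out), ctypes.byref(success)) != 0:
        raise RuntimeError('XGBoosterGetAttr failed')
    return out.value.decode('utf-8') if success.value == 1 else None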
@@ -376,6 +400,6 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(
 * \param handle handle
 * \return 0 on success, -1 on failure
 */
XGB_DLL int XGBoosterSaveRabitCheckPoint(BoosterHandle handle);
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);

#endif  // XGBOOST_C_API_H_

@@ -109,6 +109,21 @@ class Learner : public rabit::Serializable {
                        std::vector<float> *out_preds,
                        unsigned ntree_limit = 0,
                        bool pred_leaf = false) const = 0;
  /*!
   * \brief Set an additional attribute on the Booster.
   * The property will be saved along with the booster.
   * \param key The key of the property.
   * \param value The value of the property.
   */
  virtual void SetAttr(const std::string& key, const std::string& value) = 0;
  /*!
   * \brief Get an attribute from the booster.
   * The property will be saved along with the booster.
   * \param key The key of the attribute.
   * \param out The output value.
   * \return Whether the key is contained in the attributes.
   */
  virtual bool GetAttr(const std::string& key, std::string* out) const = 0;
  /*!
   * \return whether the model allows lazy checkpointing in rabit.
   */

@@ -10,6 +10,7 @@ import os

from .core import DMatrix, Booster
from .training import train, cv
from . import rabit
try:
    from .sklearn import XGBModel, XGBClassifier, XGBRegressor
    from .plotting import plot_importance, plot_tree, to_graphviz

@@ -12,9 +12,11 @@ PY3 = (sys.version_info[0] == 3)
if PY3:
    # pylint: disable=invalid-name, redefined-builtin
    STRING_TYPES = str,
    py_str = lambda x: x.decode('utf-8')
else:
    # pylint: disable=invalid-name
    STRING_TYPES = basestring,
    py_str = lambda x: x

# pandas
try:

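The py_str shim exists because the C API hands back raw bytes, which must be decoded under Python 3 but pass through unchanged under Python 2. A tiny standalone illustration of the behavior (not from this commit):

# byte string as returned by a ctypes c_char_p value
raw = b'best_score'
py_str = lambda x: x.decode('utf-8')   # the PY3 branch above
assert py_str(raw) == 'best_score'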
@@ -12,7 +12,7 @@ import scipy.sparse

from .libpath import find_lib_path

from .compat import STRING_TYPES, PY3, DataFrame
from .compat import STRING_TYPES, PY3, DataFrame, py_str

class XGBoostError(Exception):
    """Error thrown by the xgboost trainer."""
@@ -654,10 +654,63 @@ class Booster(object):
        Returns
        -------
        booster: `Booster`
            a copied booster model
        """
        return self.__copy__()

    def load_rabit_checkpoint(self):
        """Initialize the model by loading it from the rabit checkpoint.

        Returns
        -------
        version: integer
            The version number of the model.
        """
        version = ctypes.c_int()
        _check_call(_LIB.XGBoosterLoadRabitCheckpoint(
            self.handle, ctypes.byref(version)))
        return version.value

    def save_rabit_checkpoint(self):
        """Save the current booster to a rabit checkpoint."""
        _check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))

    def attr(self, key):
        """Get an attribute string from the Booster.

        Parameters
        ----------
        key : str
            The key to get the attribute from.

        Returns
        -------
        value : str
            The attribute value of the key, or None if the attribute does not exist.
        """
        ret = ctypes.c_char_p()
        success = ctypes.c_int()
        _check_call(_LIB.XGBoosterGetAttr(
            self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
        if success.value != 0:
            return py_str(ret.value)
        else:
            return None

    def set_attr(self, **kwargs):
        """Set attributes on the Booster.

        Parameters
        ----------
        **kwargs
            The attributes to set.
        """
        for key, value in kwargs.items():
            if not isinstance(value, STRING_TYPES):
                raise ValueError("set_attr only accepts string values")
            _check_call(_LIB.XGBoosterSetAttr(
                self.handle, c_str(key), c_str(str(value))))

    def set_param(self, params, value=None):
        """Set parameters into the Booster.

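A hedged usage sketch of the new Booster methods; the file path and parameter values are illustrative placeholders, not from the commit:

import xgboost as xgb

dtrain = xgb.DMatrix('train.svm')  # placeholder path
bst = xgb.train({'max_depth': 2}, dtrain, num_boost_round=2)

# Attributes are plain strings persisted with the model.
bst.set_attr(best_score='0.5')
assert bst.attr('best_score') == '0.5'
assert bst.attr('no_such_key') is None

# In a rabit cluster the checkpoint pair drives fault recovery; in a
# single process, load_rabit_checkpoint() simply reports version 0.
bst.save_rabit_checkpoint()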
python-package/xgboost/rabit.py (new file, 188 lines)
@@ -0,0 +1,188 @@
"""Distributed XGBoost Rabit related API."""
|
||||
from __future__ import absolute_import
|
||||
import sys
|
||||
import atexit
|
||||
import ctypes
|
||||
import numpy as np
|
||||
|
||||
from .core import _LIB, c_str, STRING_TYPES
|
||||
|
||||
def _init_rabit():
    """Initialize the rabit library."""
    _LIB.RabitGetRank.restype = ctypes.c_int
    _LIB.RabitGetWorldSize.restype = ctypes.c_int
    _LIB.RabitVersionNumber.restype = ctypes.c_int
    _LIB.RabitInit(0, None)


def finalize():
    """Finalize the process, notifying the tracker that everything is done."""
    _LIB.RabitFinalize()

def get_rank():
    """Get the rank of the current process.

    Returns
    -------
    rank : int
        Rank of the current process.
    """
    ret = _LIB.RabitGetRank()
    return ret


def get_world_size():
    """Get the total number of workers.

    Returns
    -------
    n : int
        Total number of processes.
    """
    ret = _LIB.RabitGetWorldSize()
    return ret

def tracker_print(msg):
    """Print a message to the tracker.

    This function can be used to communicate progress information
    to the tracker.

    Parameters
    ----------
    msg : str
        The message to be printed to the tracker.
    """
    if not isinstance(msg, STRING_TYPES):
        msg = str(msg)
    _LIB.RabitTrackerPrint(c_str(msg))

def get_processor_name():
    """Get the processor name.

    Returns
    -------
    name : str
        The name of the processor (host).
    """
    mxlen = 256
    length = ctypes.c_ulong()
    buf = ctypes.create_string_buffer(mxlen)
    _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
    return buf.value

def broadcast(data, root):
    """Broadcast an object from one node to all other nodes.

    Parameters
    ----------
    data : any type that can be pickled
        Input data; if the current rank does not equal root, this can be None.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object : same type as data
        The result of the broadcast.
    """
    rank = get_rank()
    length = ctypes.c_ulong()
    if root == rank:
        assert data is not None, 'need to pass in data when broadcasting'
        s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(s)
    # run first broadcast to share the payload length
    _LIB.RabitBroadcast(ctypes.byref(length),
                        ctypes.sizeof(ctypes.c_ulong), root)
    if root != rank:
        dptr = (ctypes.c_char * length.value)()
        # run second broadcast to receive the pickled payload
        _LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
                            length.value, root)
        data = pickle.loads(dptr.raw)
        del dptr
    else:
        _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
                            length.value, root)
        del s
    return data

# enumeration of dtypes
DTYPE_ENUM__ = {
    np.dtype('int8'): 0,
    np.dtype('uint8'): 1,
    np.dtype('int32'): 2,
    np.dtype('uint32'): 3,
    np.dtype('int64'): 4,
    np.dtype('uint64'): 5,
    np.dtype('float32'): 6,
    np.dtype('float64'): 7
}

def allreduce(data, op, prepare_fun=None):
    """Perform allreduce and return the result.

    Parameters
    ----------
    data : numpy array
        Input data.
    op : int
        Reduction operator; can be MIN, MAX, SUM, BITOR.
    prepare_fun : function
        Lazy preprocessing function; if it is not None, prepare_fun(data)
        will be called before performing the allreduce, to initialize the data.
        If the result of the allreduce can be recovered directly,
        then prepare_fun will NOT be called.

    Returns
    -------
    result : array_like
        The result of the allreduce; has the same shape as data.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise Exception('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    if buf.base is data.base:
        # ravel returned a view; copy so the result does not alias the caller's array
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise Exception('data type %s not supported' % str(buf.dtype))
    if prepare_fun is None:
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                            buf.size, DTYPE_ENUM__[buf.dtype],
                            op, None, None)
    else:
        func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
        def pfunc(args):
            """prepare function."""
            prepare_fun(data)
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                            buf.size, DTYPE_ENUM__[buf.dtype],
                            op, func_ptr(pfunc), None)
    return buf

def version_number():
    """Return the version number of the currently stored model.

    This is the number of calls to checkpoint made so far.

    Returns
    -------
    version : int
        Version number of the currently stored model.
    """
    ret = _LIB.RabitVersionNumber()
    return ret


# initialization script
_init_rabit()
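Taken together, a hedged sketch of how the module's primitives compose. It is meant to run under dmlc-submit (which sets the tracker environment); in a single process, rabit degrades gracefully to rank 0 of world size 1:

import xgboost as xgb

rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()
xgb.rabit.tracker_print('hello from %d of %d\n' % (rank, world))

# Rank 0 pickles a config dict; every other rank receives a copy.
cfg = {'eta': 0.1} if rank == 0 else None
cfg = xgb.rabit.broadcast(cfg, root=0)
assert cfg['eta'] == 0.1

xgb.rabit.finalize()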
@@ -6,9 +6,11 @@ from __future__ import absolute_import

import sys
import re
import os
import numpy as np
from .core import Booster, STRING_TYPES
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
from . import rabit

def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
          maximize=False, early_stopping_rounds=None, evals_result=None,
@@ -94,6 +96,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
        verbose_eval_every_line = verbose_eval
        verbose_eval = True if verbose_eval_every_line > 0 else False

    if rabit.get_rank() != 0:
        verbose_eval = False

    if xgb_model is not None:
        if not isinstance(xgb_model, STRING_TYPES):
            xgb_model = xgb_model.save_raw()
@@ -123,15 +128,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
        raise ValueError('For early stopping you need at least one set in evals.')

    if verbose_eval:
        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
        rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
            evals[-1][1], early_stopping_rounds))

    # is params a list of tuples? are we using multiple eval metrics?
    if isinstance(params, list):
        if len(params) != len(dict(params).items()):
            params = dict(params)
            sys.stderr.write("Multiple eval metrics have been passed: " \
                             "'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
            rabit.tracker_print("Multiple eval metrics have been passed: " \
                                "'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
        else:
            params = dict(params)

@@ -145,23 +150,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
        maximize_score = maximize

        if maximize_score:
            best_score = 0.0
            bst.set_attr(best_score='0.0')
        else:
            best_score = float('inf')

        best_msg = ''
        best_score_i = (nboost - 1)
            bst.set_attr(best_score='inf')
        bst.set_attr(best_iteration='0')

    if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
        raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")

    for i in range(nboost, nboost + num_boost_round):
    # Distributed code: Load the checkpoint from rabit.
    version = bst.load_rabit_checkpoint()
    assert(rabit.get_world_size() != 1 or version == 0)
    start_iteration = int(version / 2)
    nboost += start_iteration

    for i in range(start_iteration, num_boost_round):
        if learning_rates is not None:
            if isinstance(learning_rates, list):
                bst.set_param({'eta': learning_rates[i]})
            else:
                bst.set_param({'eta': learning_rates(i, num_boost_round)})
        bst.update(dtrain, i, obj)

        # Distributed code: need to resume to this point.
        # Skip the first update if it is a recovery step.
        if version % 2 == 0:
            bst.update(dtrain, i, obj)
            bst.save_rabit_checkpoint()
            version += 1

        assert(rabit.get_world_size() == 1 or version == rabit.version_number())

        nboost += 1
        # check evaluation result.
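The version counter advances twice per completed round, once after the boosting update and once after evaluation, which is why recovery divides by two to find the resume iteration. A small standalone illustration of that bookkeeping (not library code):

def resume_point(version):
    # version = 2 * iteration + phase; phase 1 means the boosting update
    # for that iteration was already checkpointed before the failure.
    start_iteration = version // 2
    skip_update = (version % 2 != 0)
    return start_iteration, skip_update

assert resume_point(0) == (0, False)  # fresh start
assert resume_point(5) == (2, True)   # resume at iteration 2, skip its update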
@@ -176,9 +193,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
        if verbose_eval:
            if verbose_eval_every_line:
                if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
                    sys.stderr.write(msg + '\n')
                    rabit.tracker_print(msg + '\n')
            else:
                sys.stderr.write(msg + '\n')
                rabit.tracker_print(msg + '\n')

        if evals_result is not None:
            res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
@@ -196,22 +213,26 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
        if early_stopping_rounds:
            score = float(msg.rsplit(':', 1)[1])
            best_score = float(bst.attr('best_score'))
            best_iteration = int(bst.attr('best_iteration'))
            if (maximize_score and score > best_score) or \
                    (not maximize_score and score < best_score):
                best_score = score
                best_score_i = (nboost - 1)
                best_msg = msg
            elif i - best_score_i >= early_stopping_rounds:
                # save the properties as attributes, so they will be included in the checkpoint.
                bst.set_attr(best_score=str(score),
                             best_iteration=str(nboost - 1),
                             best_msg=msg)
            elif i - best_iteration >= early_stopping_rounds:
                best_msg = bst.attr('best_msg')
                if verbose_eval:
                    sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
                # best iteration will be assigned at the end.
                bst.best_score = best_score
                bst.best_iteration = best_score_i
                    rabit.tracker_print("Stopping. Best iteration:\n{}\n\n".format(best_msg))
                break
        # do checkpoint after evaluation, in case evaluation also updates the booster.
        bst.save_rabit_checkpoint()
        version += 1

    if early_stopping_rounds:
        bst.best_score = best_score
        bst.best_iteration = best_score_i
        bst.best_score = float(bst.attr('best_score'))
        bst.best_iteration = int(bst.attr('best_iteration'))
    else:
        bst.best_iteration = nboost - 1
    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree

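Because early-stopping state now lives in booster attributes rather than Python locals, it survives a checkpoint/restore cycle. A hedged sketch of reading it back after training (data paths and parameters are placeholders, not from the commit):

import xgboost as xgb

dtrain = xgb.DMatrix('train.svm')  # placeholder path
dval = xgb.DMatrix('val.svm')      # placeholder path
bst = xgb.train({'objective': 'binary:logistic'}, dtrain, 50,
                evals=[(dval, 'val')], early_stopping_rounds=3)

# The string attributes and the mirrored Python fields agree.
print(bst.attr('best_score'), bst.attr('best_iteration'))
print(bst.best_score, bst.best_iteration)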
rabit (submodule)
@@ -1 +1 @@
Subproject commit 56ec4263f9a70a315c1f153dc5897b7c1b58250c
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
@@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
        << "will split data among workers";
  }
  *out = DMatrix::Load(
      fname, silent != 0, false);
      fname, false, true);
  API_END();
}

@@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
  API_END();
}

int XGBoosterGetAttr(BoosterHandle handle,
                     const char* key,
                     const char** out,
                     int* success) {
  Booster* bst = static_cast<Booster*>(handle);
  std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
  API_BEGIN();
  if (bst->learner()->GetAttr(key, &ret_str)) {
    *out = ret_str.c_str();
    *success = 1;
  } else {
    *out = nullptr;
    *success = 0;
  }
  API_END();
}

int XGBoosterSetAttr(BoosterHandle handle,
                     const char* key,
                     const char* value) {
  Booster* bst = static_cast<Booster*>(handle);
  API_BEGIN();
  bst->learner()->SetAttr(key, value);
  API_END();
}

int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
                                 int* version) {
  API_BEGIN();
  Booster* bst = static_cast<Booster*>(handle);
  *version = rabit::LoadCheckPoint(bst->learner());
  if (version != 0) {
  if (*version != 0) {
    bst->initialized_ = true;
  }
  API_END();
}

int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
  API_BEGIN();
  Booster* bst = static_cast<Booster*>(handle);
  if (bst->learner()->AllowLazyCheckPoint()) {

@@ -184,7 +184,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
        << " of " << npart << " parts";
  }
  // legacy handling of binary data loading
  if (file_format == "auto" && !load_row_split) {
  if (file_format == "auto" && npart == 1) {
    int magic;
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
    if (fi.get() != nullptr) {

@@ -5,6 +5,7 @@
 * \author Tianqi Chen
 */
#include <xgboost/learner.h>
#include <dmlc/io.h>
#include <algorithm>
#include <vector>
#include <utility>
@@ -43,8 +44,10 @@ struct LearnerModelParam
  unsigned num_feature;
  /* \brief number of classes, if it is multi-class classification */
  int num_class;
  /*! \brief whether the model contains additional properties */
  int contain_extra_attrs;
  /*! \brief reserved field */
  int reserved[31];
  int reserved[30];
  /*! \brief constructor */
  LearnerModelParam() {
    std::memset(this, 0, sizeof(LearnerModelParam));
@@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
    obj_.reset(ObjFunction::Create(name_obj_));
    gbm_.reset(GradientBooster::Create(name_gbm_));
    gbm_->Load(fi);
    if (mparam.contain_extra_attrs != 0) {
      std::vector<std::pair<std::string, std::string> > attr;
      fi->Read(&attr);
      attributes_ = std::map<std::string, std::string>(
          attr.begin(), attr.end());
    }

    if (metrics_.size() == 0) {
      metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
@@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
    fo->Write(name_obj_);
    fo->Write(name_gbm_);
    gbm_->Save(fo);
    if (mparam.contain_extra_attrs != 0) {
      std::vector<std::pair<std::string, std::string> > attr(
          attributes_.begin(), attributes_.end());
      fo->Write(attr);
    }
  }

  void UpdateOneIter(int iter, DMatrix* train) override {
@@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
    return os.str();
  }

  void SetAttr(const std::string& key, const std::string& value) override {
    attributes_[key] = value;
    mparam.contain_extra_attrs = 1;
  }

  bool GetAttr(const std::string& key, std::string* out) const override {
    auto it = attributes_.find(key);
    if (it == attributes_.end()) return false;
    *out = it->second;
    return true;
  }

  std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
    if (metric == "auto") metric = obj_->DefaultEvalMetric();
    std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
@@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
  LearnerTrainParam tparam;
  // configurations
  std::map<std::string, std::string> cfg_;
  // attributes
  std::map<std::string, std::string> attributes_;
  // name of gbm
  std::string name_gbm_;
  // name of objective function

@@ -56,6 +56,9 @@ class BaseMaker: public TreeUpdater {
        }
      }
    }
  }
  /*! \brief synchronize the information */
  inline void SyncInfo() {
    rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
  }
  // get feature type, 0:empty 1:binary 2:real

@@ -313,6 +313,7 @@ class CQHistMaker: public HistMaker<TStats> {
      feat_helper.InitByCol(p_fmat, tree);
      cache_dmatrix_ = p_fmat;
    }
    feat_helper.SyncInfo();
    feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
  }
  // code to create histogram

tests/distributed/runtests.sh (new executable file, 4 lines)
@@ -0,0 +1,4 @@
#!/bin/bash

PYTHONPATH=../../python-package/ ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 \
    python test_basic.py
tests/distributed/test_basic.py (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/usr/bin/python
import numpy as np
import scipy.sparse
import pickle
import xgboost as xgb

# Load the files; they will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Specify parameters via a map; definitions are the same as in the C++ version.
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}

# Specify validation sets to watch performance.
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20

# Run training; all the features of the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Save the model; only process 0 saves it.
if xgb.rabit.get_rank() == 0:
    bst.save_model("test.model")
    xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker that all training has been successful.
# This is only needed in distributed training.
xgb.rabit.finalize()
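As a follow-up to the test, a sketch of consuming the artifact it writes; this runs in a single process, with the path assumed to match the save above:

import xgboost as xgb

bst = xgb.Booster()
bst.load_model('test.model')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
preds = bst.predict(dtest)
print(preds[:5])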