[PYTHON-DIST] Distributed xgboost python training API.
This commit is contained in:
parent
51bb556898
commit
ecb3a271be
@ -1 +1 @@
|
|||||||
Subproject commit 38ee75d95ff23e4e1febacc89e08975d9b6c6c3a
|
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
|
||||||
@ -358,6 +358,30 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
|||||||
bst_ulong *out_len,
|
bst_ulong *out_len,
|
||||||
const char ***out_models);
|
const char ***out_models);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get string attribute from Booster.
|
||||||
|
* \param handle handle
|
||||||
|
* \param key The key of the attribute.
|
||||||
|
* \param out The result attribute, can be NULL if the attribute do not exist.
|
||||||
|
* \param success Whether the result is contained in out.
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
||||||
|
const char* key,
|
||||||
|
const char** out,
|
||||||
|
int *success);
|
||||||
|
/*!
|
||||||
|
* \brief Set string attribute.
|
||||||
|
*
|
||||||
|
* \param handle handle
|
||||||
|
* \param key The key of the symbol.
|
||||||
|
* \param value The value to be saved.
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
||||||
|
const char* key,
|
||||||
|
const char* value);
|
||||||
|
|
||||||
// --- Distributed training API----
|
// --- Distributed training API----
|
||||||
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
|
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
|
||||||
/*!
|
/*!
|
||||||
@ -376,6 +400,6 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(
|
|||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \return 0 when success, -1 when failure happens
|
* \return 0 when success, -1 when failure happens
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGBoosterSaveRabitCheckPoint(BoosterHandle handle);
|
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
||||||
|
|
||||||
#endif // XGBOOST_C_API_H_
|
#endif // XGBOOST_C_API_H_
|
||||||
|
|||||||
@ -109,6 +109,21 @@ class Learner : public rabit::Serializable {
|
|||||||
std::vector<float> *out_preds,
|
std::vector<float> *out_preds,
|
||||||
unsigned ntree_limit = 0,
|
unsigned ntree_limit = 0,
|
||||||
bool pred_leaf = false) const = 0;
|
bool pred_leaf = false) const = 0;
|
||||||
|
/*!
|
||||||
|
* \brief Set additional attribute to the Booster.
|
||||||
|
* The property will be saved along the booster.
|
||||||
|
* \param key The key of the property.
|
||||||
|
* \param value The value of the property.
|
||||||
|
*/
|
||||||
|
virtual void SetAttr(const std::string& key, const std::string& value) = 0;
|
||||||
|
/*!
|
||||||
|
* \brief Get attribute from the booster.
|
||||||
|
* The property will be saved along the booster.
|
||||||
|
* \param key The key of the attribute.
|
||||||
|
* \param out The output value.
|
||||||
|
* \return Whether the key is contained in the attribute.
|
||||||
|
*/
|
||||||
|
virtual bool GetAttr(const std::string& key, std::string* out) const = 0;
|
||||||
/*!
|
/*!
|
||||||
* \return whether the model allow lazy checkpoint in rabit.
|
* \return whether the model allow lazy checkpoint in rabit.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import os
|
|||||||
|
|
||||||
from .core import DMatrix, Booster
|
from .core import DMatrix, Booster
|
||||||
from .training import train, cv
|
from .training import train, cv
|
||||||
|
from . import rabit
|
||||||
try:
|
try:
|
||||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
|
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
|
||||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||||
|
|||||||
@ -12,9 +12,11 @@ PY3 = (sys.version_info[0] == 3)
|
|||||||
if PY3:
|
if PY3:
|
||||||
# pylint: disable=invalid-name, redefined-builtin
|
# pylint: disable=invalid-name, redefined-builtin
|
||||||
STRING_TYPES = str,
|
STRING_TYPES = str,
|
||||||
|
py_str = lambda x: x.decode('utf-8')
|
||||||
else:
|
else:
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
STRING_TYPES = basestring,
|
STRING_TYPES = basestring,
|
||||||
|
py_str = lambda x: x
|
||||||
|
|
||||||
# pandas
|
# pandas
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import scipy.sparse
|
|||||||
|
|
||||||
from .libpath import find_lib_path
|
from .libpath import find_lib_path
|
||||||
|
|
||||||
from .compat import STRING_TYPES, PY3, DataFrame
|
from .compat import STRING_TYPES, PY3, DataFrame, py_str
|
||||||
|
|
||||||
class XGBoostError(Exception):
|
class XGBoostError(Exception):
|
||||||
"""Error throwed by xgboost trainer."""
|
"""Error throwed by xgboost trainer."""
|
||||||
@ -654,10 +654,63 @@ class Booster(object):
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
booster: `Booster`
|
booster: `Booster`
|
||||||
a copied booster model
|
a copied booster model
|
||||||
"""
|
"""
|
||||||
return self.__copy__()
|
return self.__copy__()
|
||||||
|
|
||||||
|
def load_rabit_checkpoint(self):
|
||||||
|
"""Initialize the model by load from rabit checkpoint.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
version: integer
|
||||||
|
The version number of the model.
|
||||||
|
"""
|
||||||
|
version = ctypes.c_int()
|
||||||
|
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
|
||||||
|
self.handle, ctypes.byref(version)))
|
||||||
|
return version.value
|
||||||
|
|
||||||
|
def save_rabit_checkpoint(self):
|
||||||
|
"""Save the current booster to rabit checkpoint."""
|
||||||
|
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
|
||||||
|
|
||||||
|
def attr(self, key):
|
||||||
|
"""Get attribute string from the Booster.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
The key to get attribute from.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
value : str
|
||||||
|
The attribute value of the key, returns None if attribute do not exist.
|
||||||
|
"""
|
||||||
|
ret = ctypes.c_char_p()
|
||||||
|
success = ctypes.c_int()
|
||||||
|
_check_call(_LIB.XGBoosterGetAttr(
|
||||||
|
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
|
||||||
|
if success.value != 0:
|
||||||
|
return py_str(ret.value)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_attr(self, **kwargs):
|
||||||
|
"""Set the attribute of the Booster.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
**kwargs
|
||||||
|
The attributes to set
|
||||||
|
"""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if not isinstance(value, STRING_TYPES):
|
||||||
|
raise ValueError("Set Attr only accepts string values")
|
||||||
|
_check_call(_LIB.XGBoosterSetAttr(
|
||||||
|
self.handle, c_str(key), c_str(str(value))))
|
||||||
|
|
||||||
def set_param(self, params, value=None):
|
def set_param(self, params, value=None):
|
||||||
"""Set parameters into the Booster.
|
"""Set parameters into the Booster.
|
||||||
|
|
||||||
|
|||||||
188
python-package/xgboost/rabit.py
Normal file
188
python-package/xgboost/rabit.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
"""Distributed XGBoost Rabit related API."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import sys
|
||||||
|
import atexit
|
||||||
|
import ctypes
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .core import _LIB, c_str, STRING_TYPES
|
||||||
|
|
||||||
|
def _init_rabit():
|
||||||
|
"""Initialize the rabit library."""
|
||||||
|
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||||
|
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
||||||
|
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||||
|
_LIB.RabitInit(0, None)
|
||||||
|
|
||||||
|
|
||||||
|
def finalize():
|
||||||
|
"""Finalize the process, notify tracker everything is done."""
|
||||||
|
_LIB.RabitFinalize()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
"""Get rank of current process.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
rank : int
|
||||||
|
Rank of current process.
|
||||||
|
"""
|
||||||
|
ret = _LIB.RabitGetRank()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
"""Get total number workers.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
n : int
|
||||||
|
Total number of process.
|
||||||
|
"""
|
||||||
|
ret = _LIB.RabitGetWorldSize()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def tracker_print(msg):
|
||||||
|
"""Print message to the tracker.
|
||||||
|
|
||||||
|
This function can be used to communicate the information of
|
||||||
|
the progress to the tracker
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
msg : str
|
||||||
|
The message to be printed to tracker.
|
||||||
|
"""
|
||||||
|
if not isinstance(msg, STRING_TYPES):
|
||||||
|
msg = str(msg)
|
||||||
|
_LIB.RabitTrackerPrint(c_str(msg))
|
||||||
|
|
||||||
|
|
||||||
|
def get_processor_name():
|
||||||
|
"""Get the processor name.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
name : str
|
||||||
|
the name of processor(host)
|
||||||
|
"""
|
||||||
|
mxlen = 256
|
||||||
|
length = ctypes.c_ulong()
|
||||||
|
buf = ctypes.create_string_buffer(mxlen)
|
||||||
|
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
|
||||||
|
return buf.value
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast(data, root):
|
||||||
|
"""Broadcast object from one node to all other nodes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : any type that can be pickled
|
||||||
|
Input data, if current rank does not equal root, this can be None
|
||||||
|
root : int
|
||||||
|
Rank of the node to broadcast data from.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
object : int
|
||||||
|
the result of broadcast.
|
||||||
|
"""
|
||||||
|
rank = get_rank()
|
||||||
|
length = ctypes.c_ulong()
|
||||||
|
if root == rank:
|
||||||
|
assert data is not None, 'need to pass in data when broadcasting'
|
||||||
|
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
length.value = len(s)
|
||||||
|
# run first broadcast
|
||||||
|
_LIB.RabitBroadcast(ctypes.byref(length),
|
||||||
|
ctypes.sizeof(ctypes.c_ulong), root)
|
||||||
|
if root != rank:
|
||||||
|
dptr = (ctypes.c_char * length.value)()
|
||||||
|
# run second
|
||||||
|
_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
||||||
|
length.value, root)
|
||||||
|
data = pickle.loads(dptr.raw)
|
||||||
|
del dptr
|
||||||
|
else:
|
||||||
|
_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
||||||
|
length.value, root)
|
||||||
|
del s
|
||||||
|
return data
|
||||||
|
|
||||||
|
# enumeration of dtypes
|
||||||
|
DTYPE_ENUM__ = {
|
||||||
|
np.dtype('int8') : 0,
|
||||||
|
np.dtype('uint8') : 1,
|
||||||
|
np.dtype('int32') : 2,
|
||||||
|
np.dtype('uint32') : 3,
|
||||||
|
np.dtype('int64') : 4,
|
||||||
|
np.dtype('uint64') : 5,
|
||||||
|
np.dtype('float32') : 6,
|
||||||
|
np.dtype('float64') : 7
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def allreduce(data, op, prepare_fun=None):
|
||||||
|
"""Perform allreduce, return the result.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data: numpy array
|
||||||
|
Input data.
|
||||||
|
op: int
|
||||||
|
Reduction operators, can be MIN, MAX, SUM, BITOR
|
||||||
|
prepare_fun: function
|
||||||
|
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||||
|
will be called by the function before performing allreduce, to intialize the data
|
||||||
|
If the result of Allreduce can be recovered directly,
|
||||||
|
then prepare_fun will NOT be called
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
result : array_like
|
||||||
|
The result of allreduce, have same shape as data
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
This function is not thread-safe.
|
||||||
|
"""
|
||||||
|
if not isinstance(data, np.ndarray):
|
||||||
|
raise Exception('allreduce only takes in numpy.ndarray')
|
||||||
|
buf = data.ravel()
|
||||||
|
if buf.base is data.base:
|
||||||
|
buf = buf.copy()
|
||||||
|
if buf.dtype not in DTYPE_ENUM__:
|
||||||
|
raise Exception('data type %s not supported' % str(buf.dtype))
|
||||||
|
if prepare_fun is None:
|
||||||
|
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
|
op, None, None)
|
||||||
|
else:
|
||||||
|
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||||
|
def pfunc(args):
|
||||||
|
"""prepare function."""
|
||||||
|
prepare_fun(data)
|
||||||
|
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
|
op, func_ptr(pfunc), None)
|
||||||
|
return buf
|
||||||
|
|
||||||
|
|
||||||
|
def version_number():
|
||||||
|
"""Returns version number of current stored model.
|
||||||
|
|
||||||
|
This means how many calls to CheckPoint we made so far.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
version : int
|
||||||
|
Version number of currently stored model
|
||||||
|
"""
|
||||||
|
ret = _LIB.RabitVersionNumber()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
# intialization script
|
||||||
|
_init_rabit()
|
||||||
@ -6,9 +6,11 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
import re
|
import re
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, STRING_TYPES
|
from .core import Booster, STRING_TYPES
|
||||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
|
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
|
||||||
|
from . import rabit
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||||
maximize=False, early_stopping_rounds=None, evals_result=None,
|
maximize=False, early_stopping_rounds=None, evals_result=None,
|
||||||
@ -94,6 +96,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
verbose_eval_every_line = verbose_eval
|
verbose_eval_every_line = verbose_eval
|
||||||
verbose_eval = True if verbose_eval_every_line > 0 else False
|
verbose_eval = True if verbose_eval_every_line > 0 else False
|
||||||
|
|
||||||
|
if rabit.get_rank() != 0:
|
||||||
|
verbose_eval = False;
|
||||||
|
|
||||||
if xgb_model is not None:
|
if xgb_model is not None:
|
||||||
if not isinstance(xgb_model, STRING_TYPES):
|
if not isinstance(xgb_model, STRING_TYPES):
|
||||||
xgb_model = xgb_model.save_raw()
|
xgb_model = xgb_model.save_raw()
|
||||||
@ -123,15 +128,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
raise ValueError('For early stopping you need at least one set in evals.')
|
raise ValueError('For early stopping you need at least one set in evals.')
|
||||||
|
|
||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
|
rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
|
||||||
evals[-1][1], early_stopping_rounds))
|
evals[-1][1], early_stopping_rounds))
|
||||||
|
|
||||||
# is params a list of tuples? are we using multiple eval metrics?
|
# is params a list of tuples? are we using multiple eval metrics?
|
||||||
if isinstance(params, list):
|
if isinstance(params, list):
|
||||||
if len(params) != len(dict(params).items()):
|
if len(params) != len(dict(params).items()):
|
||||||
params = dict(params)
|
params = dict(params)
|
||||||
sys.stderr.write("Multiple eval metrics have been passed: " \
|
rabit.tracker_print("Multiple eval metrics have been passed: " \
|
||||||
"'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
|
"'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
|
||||||
else:
|
else:
|
||||||
params = dict(params)
|
params = dict(params)
|
||||||
|
|
||||||
@ -145,23 +150,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
maximize_score = maximize
|
maximize_score = maximize
|
||||||
|
|
||||||
if maximize_score:
|
if maximize_score:
|
||||||
best_score = 0.0
|
bst.set_attr(best_score='0.0')
|
||||||
else:
|
else:
|
||||||
best_score = float('inf')
|
bst.set_attr(best_score='inf')
|
||||||
|
bst.set_attr(best_iteration='0')
|
||||||
best_msg = ''
|
|
||||||
best_score_i = (nboost - 1)
|
|
||||||
|
|
||||||
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
|
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
|
||||||
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
||||||
|
|
||||||
for i in range(nboost, nboost + num_boost_round):
|
# Distributed code: Load the checkpoint from rabit.
|
||||||
|
version = bst.load_rabit_checkpoint()
|
||||||
|
assert(rabit.get_world_size() != 1 or version == 0)
|
||||||
|
start_iteration = int(version / 2)
|
||||||
|
nboost += start_iteration
|
||||||
|
|
||||||
|
for i in range(start_iteration, num_boost_round):
|
||||||
if learning_rates is not None:
|
if learning_rates is not None:
|
||||||
if isinstance(learning_rates, list):
|
if isinstance(learning_rates, list):
|
||||||
bst.set_param({'eta': learning_rates[i]})
|
bst.set_param({'eta': learning_rates[i]})
|
||||||
else:
|
else:
|
||||||
bst.set_param({'eta': learning_rates(i, num_boost_round)})
|
bst.set_param({'eta': learning_rates(i, num_boost_round)})
|
||||||
bst.update(dtrain, i, obj)
|
|
||||||
|
# Distributed code: need to resume to this point.
|
||||||
|
# Skip the first update if it is a recovery step.
|
||||||
|
if version % 2 == 0:
|
||||||
|
bst.update(dtrain, i, obj)
|
||||||
|
bst.save_rabit_checkpoint()
|
||||||
|
version += 1
|
||||||
|
|
||||||
|
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
|
||||||
|
|
||||||
nboost += 1
|
nboost += 1
|
||||||
# check evaluation result.
|
# check evaluation result.
|
||||||
@ -176,9 +193,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
if verbose_eval_every_line:
|
if verbose_eval_every_line:
|
||||||
if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
|
if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
|
||||||
sys.stderr.write(msg + '\n')
|
rabit.tracker_print(msg + '\n')
|
||||||
else:
|
else:
|
||||||
sys.stderr.write(msg + '\n')
|
rabit.tracker_print(msg + '\n')
|
||||||
|
|
||||||
if evals_result is not None:
|
if evals_result is not None:
|
||||||
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
|
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
|
||||||
@ -196,22 +213,26 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
|
|
||||||
if early_stopping_rounds:
|
if early_stopping_rounds:
|
||||||
score = float(msg.rsplit(':', 1)[1])
|
score = float(msg.rsplit(':', 1)[1])
|
||||||
|
best_score = float(bst.attr('best_score'))
|
||||||
|
best_iteration = int(bst.attr('best_iteration'))
|
||||||
if (maximize_score and score > best_score) or \
|
if (maximize_score and score > best_score) or \
|
||||||
(not maximize_score and score < best_score):
|
(not maximize_score and score < best_score):
|
||||||
best_score = score
|
# save the property to attributes, so they will occur in checkpoint.
|
||||||
best_score_i = (nboost - 1)
|
bst.set_attr(best_score=str(score),
|
||||||
best_msg = msg
|
best_iteration=str(nboost - 1),
|
||||||
elif i - best_score_i >= early_stopping_rounds:
|
best_msg=msg)
|
||||||
|
elif i - best_iteration >= early_stopping_rounds:
|
||||||
|
best_msg = bst.attr('best_msg')
|
||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
|
rabit.tracker_print("Stopping. Best iteration:\n{}\n\n".format(best_msg))
|
||||||
# best iteration will be assigned in the end.
|
|
||||||
bst.best_score = best_score
|
|
||||||
bst.best_iteration = best_score_i
|
|
||||||
break
|
break
|
||||||
|
# do checkpoint after evaluation, in case evaluation also updates booster.
|
||||||
|
bst.save_rabit_checkpoint()
|
||||||
|
version += 1
|
||||||
|
|
||||||
if early_stopping_rounds:
|
if early_stopping_rounds:
|
||||||
best_score = best_score
|
bst.best_score = float(bst.attr('best_score'))
|
||||||
bst.best_iteration = best_score_i
|
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||||
else:
|
else:
|
||||||
bst.best_iteration = nboost - 1
|
bst.best_iteration = nboost - 1
|
||||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit 56ec4263f9a70a315c1f153dc5897b7c1b58250c
|
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
||||||
@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
|
|||||||
<< "will split data among workers";
|
<< "will split data among workers";
|
||||||
}
|
}
|
||||||
*out = DMatrix::Load(
|
*out = DMatrix::Load(
|
||||||
fname, silent != 0, false);
|
fname, false, true);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int XGBoosterGetAttr(BoosterHandle handle,
|
||||||
|
const char* key,
|
||||||
|
const char** out,
|
||||||
|
int* success) {
|
||||||
|
Booster* bst = static_cast<Booster*>(handle);
|
||||||
|
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||||
|
API_BEGIN();
|
||||||
|
if (bst->learner()->GetAttr(key, &ret_str)) {
|
||||||
|
*out = ret_str.c_str();
|
||||||
|
*success = 1;
|
||||||
|
} else {
|
||||||
|
*out = nullptr;
|
||||||
|
*success = 0;
|
||||||
|
}
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
|
int XGBoosterSetAttr(BoosterHandle handle,
|
||||||
|
const char* key,
|
||||||
|
const char* value) {
|
||||||
|
Booster* bst = static_cast<Booster*>(handle);
|
||||||
|
API_BEGIN();
|
||||||
|
bst->learner()->SetAttr(key, value);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||||
int* version) {
|
int* version) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
Booster* bst = static_cast<Booster*>(handle);
|
Booster* bst = static_cast<Booster*>(handle);
|
||||||
*version = rabit::LoadCheckPoint(bst->learner());
|
*version = rabit::LoadCheckPoint(bst->learner());
|
||||||
if (version != 0) {
|
if (*version != 0) {
|
||||||
bst->initialized_ = true;
|
bst->initialized_ = true;
|
||||||
}
|
}
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
|
int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
Booster* bst = static_cast<Booster*>(handle);
|
Booster* bst = static_cast<Booster*>(handle);
|
||||||
if (bst->learner()->AllowLazyCheckPoint()) {
|
if (bst->learner()->AllowLazyCheckPoint()) {
|
||||||
|
|||||||
@ -184,7 +184,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
<< " of " << npart << " parts";
|
<< " of " << npart << " parts";
|
||||||
}
|
}
|
||||||
// legacy handling of binary data loading
|
// legacy handling of binary data loading
|
||||||
if (file_format == "auto" && !load_row_split) {
|
if (file_format == "auto" && npart == 1) {
|
||||||
int magic;
|
int magic;
|
||||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||||
if (fi.get() != nullptr) {
|
if (fi.get() != nullptr) {
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#include <xgboost/learner.h>
|
#include <xgboost/learner.h>
|
||||||
|
#include <dmlc/io.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -43,8 +44,10 @@ struct LearnerModelParam
|
|||||||
unsigned num_feature;
|
unsigned num_feature;
|
||||||
/* \brief number of classes, if it is multi-class classification */
|
/* \brief number of classes, if it is multi-class classification */
|
||||||
int num_class;
|
int num_class;
|
||||||
|
/*! \brief Model contain additional properties */
|
||||||
|
int contain_extra_attrs;
|
||||||
/*! \brief reserved field */
|
/*! \brief reserved field */
|
||||||
int reserved[31];
|
int reserved[30];
|
||||||
/*! \brief constructor */
|
/*! \brief constructor */
|
||||||
LearnerModelParam() {
|
LearnerModelParam() {
|
||||||
std::memset(this, 0, sizeof(LearnerModelParam));
|
std::memset(this, 0, sizeof(LearnerModelParam));
|
||||||
@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
|
|||||||
obj_.reset(ObjFunction::Create(name_obj_));
|
obj_.reset(ObjFunction::Create(name_obj_));
|
||||||
gbm_.reset(GradientBooster::Create(name_gbm_));
|
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||||
gbm_->Load(fi);
|
gbm_->Load(fi);
|
||||||
|
if (mparam.contain_extra_attrs != 0) {
|
||||||
|
std::vector<std::pair<std::string, std::string> > attr;
|
||||||
|
fi->Read(&attr);
|
||||||
|
attributes_ = std::map<std::string, std::string>(
|
||||||
|
attr.begin(), attr.end());
|
||||||
|
}
|
||||||
|
|
||||||
if (metrics_.size() == 0) {
|
if (metrics_.size() == 0) {
|
||||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||||
@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
|
|||||||
fo->Write(name_obj_);
|
fo->Write(name_obj_);
|
||||||
fo->Write(name_gbm_);
|
fo->Write(name_gbm_);
|
||||||
gbm_->Save(fo);
|
gbm_->Save(fo);
|
||||||
|
if (mparam.contain_extra_attrs != 0) {
|
||||||
|
std::vector<std::pair<std::string, std::string> > attr(
|
||||||
|
attributes_.begin(), attributes_.end());
|
||||||
|
fo->Write(attr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||||
@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
|
|||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetAttr(const std::string& key, const std::string& value) override {
|
||||||
|
attributes_[key] = value;
|
||||||
|
mparam.contain_extra_attrs = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GetAttr(const std::string& key, std::string* out) const override {
|
||||||
|
auto it = attributes_.find(key);
|
||||||
|
if (it == attributes_.end()) return false;
|
||||||
|
*out = it->second;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
|
std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
|
||||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||||
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||||
@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
|
|||||||
LearnerTrainParam tparam;
|
LearnerTrainParam tparam;
|
||||||
// configurations
|
// configurations
|
||||||
std::map<std::string, std::string> cfg_;
|
std::map<std::string, std::string> cfg_;
|
||||||
|
// attributes
|
||||||
|
std::map<std::string, std::string> attributes_;
|
||||||
// name of gbm
|
// name of gbm
|
||||||
std::string name_gbm_;
|
std::string name_gbm_;
|
||||||
// name of objective functon
|
// name of objective functon
|
||||||
|
|||||||
@ -56,6 +56,9 @@ class BaseMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
/*! \brief synchronize the information */
|
||||||
|
inline void SyncInfo() {
|
||||||
rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
|
rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
|
||||||
}
|
}
|
||||||
// get feature type, 0:empty 1:binary 2:real
|
// get feature type, 0:empty 1:binary 2:real
|
||||||
|
|||||||
@ -313,6 +313,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
feat_helper.InitByCol(p_fmat, tree);
|
feat_helper.InitByCol(p_fmat, tree);
|
||||||
cache_dmatrix_ = p_fmat;
|
cache_dmatrix_ = p_fmat;
|
||||||
}
|
}
|
||||||
|
feat_helper.SyncInfo();
|
||||||
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
||||||
}
|
}
|
||||||
// code to create histogram
|
// code to create histogram
|
||||||
|
|||||||
4
tests/distributed/runtests.sh
Executable file
4
tests/distributed/runtests.sh
Executable file
@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
PYTHONPATH=../../python-package/ ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3\
|
||||||
|
python test_basic.py
|
||||||
29
tests/distributed/test_basic.py
Normal file
29
tests/distributed/test_basic.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/python
|
||||||
|
import numpy as np
|
||||||
|
import scipy.sparse
|
||||||
|
import pickle
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
# Load file, file will be automatically sharded in distributed mode.
|
||||||
|
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||||
|
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||||
|
|
||||||
|
# specify parameters via map, definition are same as c++ version
|
||||||
|
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||||
|
|
||||||
|
# specify validations set to watch performance
|
||||||
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
|
num_round = 20
|
||||||
|
|
||||||
|
# Run training, all the features in training API is available.
|
||||||
|
# Currently, this script only support calling train once for fault recovery purpose.
|
||||||
|
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||||
|
|
||||||
|
# save the model, only ask process 0 to save the model.
|
||||||
|
if xgb.rabit.get_rank() == 0:
|
||||||
|
bst.save_model("test.model")
|
||||||
|
xgb.rabit.tracker_print("Finished training\n")
|
||||||
|
|
||||||
|
# Notify the tracker all training has been successful
|
||||||
|
# This is only needed in distributed training.
|
||||||
|
xgb.rabit.finalize()
|
||||||
Loading…
x
Reference in New Issue
Block a user