[PYTHON-DIST] Distributed xgboost python training API.
This commit is contained in:
@@ -10,6 +10,7 @@ import os
|
||||
|
||||
from .core import DMatrix, Booster
|
||||
from .training import train, cv
|
||||
from . import rabit
|
||||
try:
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
|
||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||
|
||||
@@ -12,9 +12,11 @@ PY3 = (sys.version_info[0] == 3)
|
||||
if PY3:
|
||||
# pylint: disable=invalid-name, redefined-builtin
|
||||
STRING_TYPES = str,
|
||||
py_str = lambda x: x.decode('utf-8')
|
||||
else:
|
||||
# pylint: disable=invalid-name
|
||||
STRING_TYPES = basestring,
|
||||
py_str = lambda x: x
|
||||
|
||||
# pandas
|
||||
try:
|
||||
|
||||
@@ -12,7 +12,7 @@ import scipy.sparse
|
||||
|
||||
from .libpath import find_lib_path
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str
|
||||
|
||||
class XGBoostError(Exception):
|
||||
"""Error throwed by xgboost trainer."""
|
||||
@@ -654,10 +654,63 @@ class Booster(object):
|
||||
Returns
|
||||
-------
|
||||
booster: `Booster`
|
||||
a copied booster model
|
||||
a copied booster model
|
||||
"""
|
||||
return self.__copy__()
|
||||
|
||||
def load_rabit_checkpoint(self):
|
||||
"""Initialize the model by load from rabit checkpoint.
|
||||
|
||||
Returns
|
||||
-------
|
||||
version: integer
|
||||
The version number of the model.
|
||||
"""
|
||||
version = ctypes.c_int()
|
||||
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
|
||||
self.handle, ctypes.byref(version)))
|
||||
return version.value
|
||||
|
||||
def save_rabit_checkpoint(self):
|
||||
"""Save the current booster to rabit checkpoint."""
|
||||
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
|
||||
|
||||
def attr(self, key):
|
||||
"""Get attribute string from the Booster.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key to get attribute from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
value : str
|
||||
The attribute value of the key, returns None if attribute do not exist.
|
||||
"""
|
||||
ret = ctypes.c_char_p()
|
||||
success = ctypes.c_int()
|
||||
_check_call(_LIB.XGBoosterGetAttr(
|
||||
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
|
||||
if success.value != 0:
|
||||
return py_str(ret.value)
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_attr(self, **kwargs):
|
||||
"""Set the attribute of the Booster.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
**kwargs
|
||||
The attributes to set
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not isinstance(value, STRING_TYPES):
|
||||
raise ValueError("Set Attr only accepts string values")
|
||||
_check_call(_LIB.XGBoosterSetAttr(
|
||||
self.handle, c_str(key), c_str(str(value))))
|
||||
|
||||
def set_param(self, params, value=None):
|
||||
"""Set parameters into the Booster.
|
||||
|
||||
|
||||
188
python-package/xgboost/rabit.py
Normal file
188
python-package/xgboost/rabit.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Distributed XGBoost Rabit related API."""
|
||||
from __future__ import absolute_import
|
||||
import sys
|
||||
import atexit
|
||||
import ctypes
|
||||
import numpy as np
|
||||
|
||||
from .core import _LIB, c_str, STRING_TYPES
|
||||
|
||||
def _init_rabit():
|
||||
"""Initialize the rabit library."""
|
||||
_LIB.RabitGetRank.restype = ctypes.c_int
|
||||
_LIB.RabitGetWorldSize.restype = ctypes.c_int
|
||||
_LIB.RabitVersionNumber.restype = ctypes.c_int
|
||||
_LIB.RabitInit(0, None)
|
||||
|
||||
|
||||
def finalize():
|
||||
"""Finalize the process, notify tracker everything is done."""
|
||||
_LIB.RabitFinalize()
|
||||
|
||||
|
||||
def get_rank():
|
||||
"""Get rank of current process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rank : int
|
||||
Rank of current process.
|
||||
"""
|
||||
ret = _LIB.RabitGetRank()
|
||||
return ret
|
||||
|
||||
|
||||
def get_world_size():
|
||||
"""Get total number workers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
n : int
|
||||
Total number of process.
|
||||
"""
|
||||
ret = _LIB.RabitGetWorldSize()
|
||||
return ret
|
||||
|
||||
|
||||
def tracker_print(msg):
|
||||
"""Print message to the tracker.
|
||||
|
||||
This function can be used to communicate the information of
|
||||
the progress to the tracker
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : str
|
||||
The message to be printed to tracker.
|
||||
"""
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = str(msg)
|
||||
_LIB.RabitTrackerPrint(c_str(msg))
|
||||
|
||||
|
||||
def get_processor_name():
|
||||
"""Get the processor name.
|
||||
|
||||
Returns
|
||||
-------
|
||||
name : str
|
||||
the name of processor(host)
|
||||
"""
|
||||
mxlen = 256
|
||||
length = ctypes.c_ulong()
|
||||
buf = ctypes.create_string_buffer(mxlen)
|
||||
_LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
|
||||
return buf.value
|
||||
|
||||
|
||||
def broadcast(data, root):
|
||||
"""Broadcast object from one node to all other nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : any type that can be pickled
|
||||
Input data, if current rank does not equal root, this can be None
|
||||
root : int
|
||||
Rank of the node to broadcast data from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object : int
|
||||
the result of broadcast.
|
||||
"""
|
||||
rank = get_rank()
|
||||
length = ctypes.c_ulong()
|
||||
if root == rank:
|
||||
assert data is not None, 'need to pass in data when broadcasting'
|
||||
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
length.value = len(s)
|
||||
# run first broadcast
|
||||
_LIB.RabitBroadcast(ctypes.byref(length),
|
||||
ctypes.sizeof(ctypes.c_ulong), root)
|
||||
if root != rank:
|
||||
dptr = (ctypes.c_char * length.value)()
|
||||
# run second
|
||||
_LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
|
||||
length.value, root)
|
||||
data = pickle.loads(dptr.raw)
|
||||
del dptr
|
||||
else:
|
||||
_LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
|
||||
length.value, root)
|
||||
del s
|
||||
return data
|
||||
|
||||
# enumeration of dtypes
|
||||
DTYPE_ENUM__ = {
|
||||
np.dtype('int8') : 0,
|
||||
np.dtype('uint8') : 1,
|
||||
np.dtype('int32') : 2,
|
||||
np.dtype('uint32') : 3,
|
||||
np.dtype('int64') : 4,
|
||||
np.dtype('uint64') : 5,
|
||||
np.dtype('float32') : 6,
|
||||
np.dtype('float64') : 7
|
||||
}
|
||||
|
||||
|
||||
def allreduce(data, op, prepare_fun=None):
|
||||
"""Perform allreduce, return the result.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: numpy array
|
||||
Input data.
|
||||
op: int
|
||||
Reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun: function
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to intialize the data
|
||||
If the result of Allreduce can be recovered directly,
|
||||
then prepare_fun will NOT be called
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : array_like
|
||||
The result of allreduce, have same shape as data
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is not thread-safe.
|
||||
"""
|
||||
if not isinstance(data, np.ndarray):
|
||||
raise Exception('allreduce only takes in numpy.ndarray')
|
||||
buf = data.ravel()
|
||||
if buf.base is data.base:
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise Exception('data type %s not supported' % str(buf.dtype))
|
||||
if prepare_fun is None:
|
||||
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, None, None)
|
||||
else:
|
||||
func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
def pfunc(args):
|
||||
"""prepare function."""
|
||||
prepare_fun(data)
|
||||
_LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, func_ptr(pfunc), None)
|
||||
return buf
|
||||
|
||||
|
||||
def version_number():
|
||||
"""Returns version number of current stored model.
|
||||
|
||||
This means how many calls to CheckPoint we made so far.
|
||||
|
||||
Returns
|
||||
-------
|
||||
version : int
|
||||
Version number of currently stored model
|
||||
"""
|
||||
ret = _LIB.RabitVersionNumber()
|
||||
return ret
|
||||
|
||||
# intialization script
|
||||
_init_rabit()
|
||||
@@ -6,9 +6,11 @@ from __future__ import absolute_import
|
||||
|
||||
import sys
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from .core import Booster, STRING_TYPES
|
||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
|
||||
from . import rabit
|
||||
|
||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
maximize=False, early_stopping_rounds=None, evals_result=None,
|
||||
@@ -94,6 +96,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
verbose_eval_every_line = verbose_eval
|
||||
verbose_eval = True if verbose_eval_every_line > 0 else False
|
||||
|
||||
if rabit.get_rank() != 0:
|
||||
verbose_eval = False;
|
||||
|
||||
if xgb_model is not None:
|
||||
if not isinstance(xgb_model, STRING_TYPES):
|
||||
xgb_model = xgb_model.save_raw()
|
||||
@@ -123,15 +128,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
raise ValueError('For early stopping you need at least one set in evals.')
|
||||
|
||||
if verbose_eval:
|
||||
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
|
||||
rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
|
||||
evals[-1][1], early_stopping_rounds))
|
||||
|
||||
# is params a list of tuples? are we using multiple eval metrics?
|
||||
if isinstance(params, list):
|
||||
if len(params) != len(dict(params).items()):
|
||||
params = dict(params)
|
||||
sys.stderr.write("Multiple eval metrics have been passed: " \
|
||||
"'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
|
||||
rabit.tracker_print("Multiple eval metrics have been passed: " \
|
||||
"'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
|
||||
else:
|
||||
params = dict(params)
|
||||
|
||||
@@ -145,23 +150,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
maximize_score = maximize
|
||||
|
||||
if maximize_score:
|
||||
best_score = 0.0
|
||||
bst.set_attr(best_score='0.0')
|
||||
else:
|
||||
best_score = float('inf')
|
||||
|
||||
best_msg = ''
|
||||
best_score_i = (nboost - 1)
|
||||
bst.set_attr(best_score='inf')
|
||||
bst.set_attr(best_iteration='0')
|
||||
|
||||
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
|
||||
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
||||
|
||||
for i in range(nboost, nboost + num_boost_round):
|
||||
# Distributed code: Load the checkpoint from rabit.
|
||||
version = bst.load_rabit_checkpoint()
|
||||
assert(rabit.get_world_size() != 1 or version == 0)
|
||||
start_iteration = int(version / 2)
|
||||
nboost += start_iteration
|
||||
|
||||
for i in range(start_iteration, num_boost_round):
|
||||
if learning_rates is not None:
|
||||
if isinstance(learning_rates, list):
|
||||
bst.set_param({'eta': learning_rates[i]})
|
||||
else:
|
||||
bst.set_param({'eta': learning_rates(i, num_boost_round)})
|
||||
bst.update(dtrain, i, obj)
|
||||
|
||||
# Distributed code: need to resume to this point.
|
||||
# Skip the first update if it is a recovery step.
|
||||
if version % 2 == 0:
|
||||
bst.update(dtrain, i, obj)
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
|
||||
|
||||
nboost += 1
|
||||
# check evaluation result.
|
||||
@@ -176,9 +193,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
if verbose_eval:
|
||||
if verbose_eval_every_line:
|
||||
if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
|
||||
sys.stderr.write(msg + '\n')
|
||||
rabit.tracker_print(msg + '\n')
|
||||
else:
|
||||
sys.stderr.write(msg + '\n')
|
||||
rabit.tracker_print(msg + '\n')
|
||||
|
||||
if evals_result is not None:
|
||||
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
|
||||
@@ -196,22 +213,26 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
|
||||
if early_stopping_rounds:
|
||||
score = float(msg.rsplit(':', 1)[1])
|
||||
best_score = float(bst.attr('best_score'))
|
||||
best_iteration = int(bst.attr('best_iteration'))
|
||||
if (maximize_score and score > best_score) or \
|
||||
(not maximize_score and score < best_score):
|
||||
best_score = score
|
||||
best_score_i = (nboost - 1)
|
||||
best_msg = msg
|
||||
elif i - best_score_i >= early_stopping_rounds:
|
||||
# save the property to attributes, so they will occur in checkpoint.
|
||||
bst.set_attr(best_score=str(score),
|
||||
best_iteration=str(nboost - 1),
|
||||
best_msg=msg)
|
||||
elif i - best_iteration >= early_stopping_rounds:
|
||||
best_msg = bst.attr('best_msg')
|
||||
if verbose_eval:
|
||||
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
|
||||
# best iteration will be assigned in the end.
|
||||
bst.best_score = best_score
|
||||
bst.best_iteration = best_score_i
|
||||
rabit.tracker_print("Stopping. Best iteration:\n{}\n\n".format(best_msg))
|
||||
break
|
||||
# do checkpoint after evaluation, in case evaluation also updates booster.
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
if early_stopping_rounds:
|
||||
best_score = best_score
|
||||
bst.best_iteration = best_score_i
|
||||
bst.best_score = float(bst.attr('best_score'))
|
||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||
else:
|
||||
bst.best_iteration = nboost - 1
|
||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||
|
||||
Reference in New Issue
Block a user