* support run rabit tests as xgboost subproject using xgboost/dmlc-core * support tracker config set/get * remove redudant printf * remove redudant printf * add c++0x declaration * log allreduce/broadcast caller, engine should track caller stack for investigation * tracker support binary config format * Revert "tracker support binary config format" This reverts commit 2a28e5e2b55c200cb621af8d19f17ab1bc62503b. * remove caller, prototype fetch allreduce/broadcast results from resbuf * store cached allreduce/broadcast seq_no to tracker * allow restore all caches from other nodes * try new rabit collective cache, todo: recv_link seems down * link up cache restore with main recovery * cleanup load cache state * update cache api * pass test.mk * have a working tests * try to unify check into actionsummary * more logging to debug distributed hist three method issue * update rabit interface to support caller signature matching * splite seq_counter from cur_cache_seq to different variables * still see issue with inf loop * support debug print caller as well as allreduce op * cleanup * remove get/set cache from model_recover, adding recover in loadcheckpoint * clarify rabit cache strategy, cache is set only by successful collective call involving all nodes with unique cache key. if all nodes call getcache at same time, we keep rabit run collective call. 
If some nodes call getcache while others not, we backfill cache from those nodes with most entries * revert caller logs * fix lint error * fix engine mpi signature * support getcache by ref * allow result buffer presiet to filestream * add loging * try fix checkpoint failure recovery case * use int64_t to avoid overflow caused seq fault * try avoid int overflow * try fix checkpoint failure recovery case * try avoid seqno overflow to negative by offseting specifial flag value adding cache seq no to checkpoint/load checkpoint/check point ack to avoid confusion from cache recovery * fix cache seq assert error * remove loging, handle edge case * add extensive log to checkpoint state with different seq no * fix lint errors * clean up comments before merge back to master * add logs to allreduce/broadcast/checkpoint * use unsinged int 32 and give seq no larger range * address remove allreduce dropseq code segment * using caller signature to filter bootstrapallreduces * remove get/set cache from empty * apply signature to reducer * apply signature to broadcast * add key to broadcat log * fix broadcast signature * fix default _line value for non linux system * adding comments, remove sleep(1) * fix osx build issue * try fix mpi * fix doc * fix engine_empty api * logging, adding more logs, restore immutable assertion * print unsinged int with ud * fix lint * rename seqtype to kSeq and KCache indicating it's usage apply kDiffSeq check to load_cache routine * comment allreduce/broadcast log * allow tests run on arm * enable flag to turn on / off cache * add log info alert if user choose to enable rabit bootstrap cache * add rabit_debug setting so user can use config to turn on * log flags when user turn on rabit_debug * force rabit restart if tracker assign -1 rank * use OPENMP to vecotrize reducer * address comment * Revert "address comment" This reverts commit 1dc61f33e7357dad8fa65528abeb81db92c5f9ed. 
* fix checkpoint size print 0 * per feedback, remove DISABLEOPEMP, address race condition * - remove openmp from this pr - update name from cache to boostrapcache * add default value of signature macros * remove openmp from cmake file * Update src/allreduce_robust.cc Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * Update src/allreduce_robust.cc Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * run test with cmake * remove openmp * fix cmake based tests * use cmake test fix darwin .dylib issue * move around rabit_signature definition due to windows build * misc, add c++ check in CMakeFile * per feedback * resolve CMake file * update rabit version
365 lines
10 KiB
Python
365 lines
10 KiB
Python
"""
|
|
Reliable Allreduce and Broadcast Library.
|
|
|
|
Author: Tianqi Chen
|
|
"""
|
|
# pylint: disable=unused-argument,invalid-name,global-statement,dangerous-default-value,
|
|
import pickle
|
|
import ctypes
|
|
import os
|
|
import platform
|
|
import sys
|
|
import warnings
|
|
import numpy as np
|
|
|
|
# version information about the doc
__version__ = '1.0'

# Handle to the loaded rabit shared library (a ctypes CDLL object).
# Populated by _loadlib() / init() and reset to None by _unloadlib().
_LIB = None
|
|
|
|
def _find_lib_path(dll_name):
|
|
"""Find the rabit dynamic library files.
|
|
|
|
Returns
|
|
-------
|
|
lib_path: list(string)
|
|
List of all found library path to rabit
|
|
"""
|
|
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
|
|
# make pythonpack hack: copy this directory one level upper for setup.py
|
|
dll_path = [curr_path,
|
|
os.path.join(curr_path, '../lib/'),
|
|
os.path.join(curr_path, './lib/')]
|
|
if os.name == 'nt':
|
|
dll_path = [os.path.join(p, dll_name) for p in dll_path]
|
|
else:
|
|
dll_path = [os.path.join(p, dll_name) for p in dll_path]
|
|
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
|
|
#From github issues, most of installation errors come from machines w/o compilers
|
|
if len(lib_path) == 0 and not os.environ.get('XGBOOST_BUILD_DOC', False):
|
|
raise RuntimeError(
|
|
'Cannot find Rabit Libarary in the candicate path, ' +
|
|
'did you install compilers and run build.sh in root path?\n'
|
|
'List of candidates:\n' + ('\n'.join(dll_path)))
|
|
return lib_path
|
|
|
|
# load in xgboost library
|
|
def _loadlib(lib='standard', lib_dll=None):
    """Load rabit library.

    Parameters
    ----------
    lib : {'standard', 'mock', 'mpi'}, optional
        Which variant of the rabit shared library to load.
    lib_dll : ctypes.CDLL, optional
        Pre-loaded DLL object; when given, ``lib`` is ignored.
    """
    global _LIB
    if _LIB is not None:
        # BUG FIX: warnings.warn() has no 'level' keyword — the correct
        # keyword is 'stacklevel'; the original call raised TypeError
        # whenever the library was initialized twice.
        warnings.warn('rabit.init call was ignored because it has'
                      ' already been initialized', stacklevel=2)
        return

    if lib_dll is not None:
        _LIB = lib_dll
        return

    if lib == 'standard':
        dll_name = 'librabit'
    else:
        dll_name = 'librabit_' + lib

    # Pick the platform-specific shared-library suffix.
    if os.name == 'nt':
        dll_name += '.dll'
    elif platform.system() == 'Darwin':
        dll_name += '.dylib'
    else:
        dll_name += '.so'

    _LIB = ctypes.cdll.LoadLibrary(_find_lib_path(dll_name)[0])
    # Declare int return types; ctypes defaults to c_int, but being
    # explicit documents the C API contract.
    _LIB.RabitGetRank.restype = ctypes.c_int
    _LIB.RabitGetWorldSize.restype = ctypes.c_int
    _LIB.RabitVersionNumber.restype = ctypes.c_int
|
|
|
|
def _unloadlib():
    """Unload rabit library.

    Resets the module-level ``_LIB`` handle to None so a subsequent
    ``_loadlib`` call can initialize the library again.
    """
    global _LIB
    # Drop the ctypes handle; rabit treats a None handle as "not loaded".
    del _LIB
    _LIB = None
|
|
|
|
# reduction operators
# Integer op codes passed straight through to RabitAllreduce (see
# allreduce below); values must match the engine's C-side enum.
MAX = 0    # element-wise maximum
MIN = 1    # element-wise minimum
SUM = 2    # element-wise sum
BITOR = 3  # element-wise bitwise OR (integer dtypes)
|
|
|
|
def init(args=None, lib='standard', lib_dll=None):
    """Initialize the rabit module, call this once before using anything.

    Parameters
    ----------
    args: list of str, optional
        The list of arguments used to initialized the rabit
        usually you need to pass in sys.argv.
        Defaults to the empty list when it is None.
    lib: {'standard', 'mock', 'mpi'}, optional
        Type of library we want to load
        Ignored when ``lib_dll`` is specified.
    lib_dll: ctypes.DLL, optional
        The DLL object used as lib.
        When this is presented argument lib will be ignored.
    """
    if args is None:
        args = []
    _loadlib(lib, lib_dll)
    arr = (ctypes.c_char_p * len(args))()
    # BUG FIX: ctypes.c_char_p holds bytes; under Python 3 assigning str
    # elements raises TypeError, so encode str arguments first.
    arr[:] = [a.encode('utf-8') if isinstance(a, str) else a for a in args]
    _LIB.RabitInit(len(args), arr)
|
|
|
|
def finalize():
    """Shut down the rabit engine.

    Invoke this once all distributed jobs have completed; it tells the
    C engine to finalize and then releases the loaded library handle.
    """
    _LIB.RabitFinalize()
    _unloadlib()
|
|
|
|
def get_rank():
    """Get rank of current process.

    Returns
    -------
    rank : int
        Rank of current process.
    """
    # Thin wrapper: the C call already returns the rank as an int.
    return _LIB.RabitGetRank()
|
|
|
|
def get_world_size():
    """Get total number workers.

    Returns
    -------
    n : int
        Total number of process.
    """
    # Thin wrapper: the C call already returns the world size as an int.
    return _LIB.RabitGetWorldSize()
|
|
|
|
def tracker_print(msg):
    """Print message to the tracker.

    This function can be used to communicate the information of
    the progress to the tracker

    Parameters
    ----------
    msg : str
        The message to be printed to tracker.
    """
    if not isinstance(msg, str):
        msg = str(msg)
    # BUG FIX: encode the str to bytes BEFORE wrapping in c_char_p — the
    # original called .encode() on the c_char_p object itself, which has
    # no such method (AttributeError), and c_char_p cannot hold a str.
    _LIB.RabitTrackerPrint(ctypes.c_char_p(msg.encode('utf-8')))
|
|
|
|
def get_processor_name():
    """Get the processor name.

    Returns
    -------
    name : str
        the name of processor(host)
    """
    # Fixed-size out-buffer filled by the C engine; 256 bytes max.
    max_len = 256
    name_len = ctypes.c_ulong()
    name_buf = ctypes.create_string_buffer(max_len)
    _LIB.RabitGetProcessorName(name_buf, ctypes.byref(name_len), max_len)
    return name_buf.value
|
|
|
|
def broadcast(data, root):
    """Broadcast object from one node to all other nodes.

    Parameters
    ----------
    data : any type that can be pickled
        Input data, if current rank does not equal root, this can be None
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object : any type that can be pickled
        the result of broadcast.
    """
    rank = get_rank()
    length = ctypes.c_ulong()
    # Two-phase protocol: the root first broadcasts the pickled payload
    # length, then the payload itself, so receivers can size their buffer.
    if root == rank:
        assert data is not None, 'need to pass in data when broadcasting'
        s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(s)
    # run first broadcast
    _LIB.RabitBroadcast(ctypes.byref(length),
                        ctypes.sizeof(ctypes.c_ulong), root)
    if root != rank:
        # Receiver: allocate a buffer of the announced size and unpickle.
        dptr = (ctypes.c_char * length.value)()
        # run second
        _LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
                            length.value, root)
        # NOTE(review): pickle.loads on bytes received from peer nodes —
        # only safe within a trusted cluster.
        data = pickle.loads(dptr.raw)
        del dptr
    else:
        # Sender: ship the pickled bytes directly.
        _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
                            length.value, root)
        del s
    return data
|
|
|
|
# enumeration of dtypes
# Maps numpy dtypes to the integer codes handed to RabitAllreduce (see
# allreduce below); codes must stay in sync with the engine's C side.
DTYPE_ENUM__ = {
    np.dtype('int8') : 0,
    np.dtype('uint8') : 1,
    np.dtype('int32') : 2,
    np.dtype('uint32') : 3,
    np.dtype('int64') : 4,
    np.dtype('uint64') : 5,
    np.dtype('float32') : 6,
    np.dtype('float64') : 7
}
|
|
|
|
def allreduce(data, op, prepare_fun=None):
    """Perform allreduce, return the result.

    Parameters
    ----------
    data: numpy array
        Input data.
    op: int
        Reduction operators, can be MIN, MAX, SUM, BITOR
    prepare_fun: function
        Lazy preprocessing function, if it is not None, prepare_fun(data)
        will be called by the function before performing allreduce, to initialize the data
        If the result of Allreduce can be recovered directly,
        then prepare_fun will NOT be called

    Returns
    -------
    result : array_like
        The result of allreduce, have same shape as data

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise Exception('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    # If ravel() returned a view over the caller's memory, copy it so the
    # engine's write into buf does not alias the caller's array.
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise Exception('data type %s not supported' % str(buf.dtype))
    if prepare_fun is None:
        # No lazy-prep callback: pass NULL for both callback arguments.
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                            buf.size, DTYPE_ENUM__[buf.dtype],
                            op, None, None)
    else:
        # Wrap the Python callable in a ctypes function pointer; the
        # engine may invoke it to fill `data` before reducing.
        func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
        def pfunc(args):
            """prepare function."""
            prepare_fun(data)
        # func_ptr(pfunc) is passed inline so the wrapper stays alive
        # for the duration of the call.
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                            buf.size, DTYPE_ENUM__[buf.dtype],
                            op, func_ptr(pfunc), None)
    return buf
|
|
|
|
|
|
def _load_model(ptr, length):
|
|
"""
|
|
Internal function used by the module,
|
|
unpickle a model from a buffer specified by ptr, length
|
|
Arguments:
|
|
ptr: ctypes.POINTER(ctypes._char)
|
|
pointer to the memory region of buffer
|
|
length: int
|
|
the length of buffer
|
|
"""
|
|
data = (ctypes.c_char * length).from_address(ctypes.addressof(ptr.contents))
|
|
return pickle.loads(data.raw)
|
|
|
|
def load_checkpoint(with_local=False):
    """Load latest check point.

    Parameters
    ----------
    with_local: bool, optional
        whether the checkpoint contains local model

    Returns
    -------
    tuple : tuple
        if with_local: return (version, global_model, local_model)
        else return (version, global_model)
        if returned version == 0, this means no model has been CheckPointed
        and global_model, local_model returned will be None
    """
    # Out-parameters filled in by the C engine: a pointer to the pickled
    # global model and its byte length.
    gptr = ctypes.POINTER(ctypes.c_char)()
    global_len = ctypes.c_ulong()
    if with_local:
        # Additional out-parameters for the node-local model.
        lptr = ctypes.POINTER(ctypes.c_char)()
        local_len = ctypes.c_ulong()
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(gptr),
            ctypes.byref(global_len),
            ctypes.byref(lptr),
            ctypes.byref(local_len))
        # version == 0 means no checkpoint has been made yet.
        if version == 0:
            return (version, None, None)
        return (version,
                _load_model(gptr, global_len.value),
                _load_model(lptr, local_len.value))
    else:
        # Caller does not want a local model: pass NULL for both slots.
        version = _LIB.RabitLoadCheckPoint(
            ctypes.byref(gptr),
            ctypes.byref(global_len),
            None, None)
        if version == 0:
            return (version, None)
        return (version,
                _load_model(gptr, global_len.value))
|
|
|
|
def checkpoint(global_model, local_model=None):
    """Checkpoint the model.

    This means we finished a stage of execution.
    Every time we call check point, there is a version number which will increase by one.

    Parameters
    ----------
    global_model: anytype that can be pickled
        globally shared model/state when calling this function,
        the caller need to gauranttees that global_model is the same in all nodes

    local_model: anytype that can be pickled
        Local model, that is specific to current node/rank.
        This can be None when no local state is needed.

    Notes
    -----
    local_model requires explicit replication of the model for fault-tolerance.
    This will bring replication cost in checkpoint function.
    while global_model do not need explicit replication.
    It is recommended to use global_model if possible.
    """
    serialized_global = pickle.dumps(global_model)
    if local_model is None:
        # No local state: hand a NULL pointer and zero length to the engine.
        _LIB.RabitCheckPoint(serialized_global, len(serialized_global),
                             None, 0)
        del serialized_global
    else:
        serialized_local = pickle.dumps(local_model)
        _LIB.RabitCheckPoint(serialized_global, len(serialized_global),
                             serialized_local, len(serialized_local))
        # Release the pickled buffers once handed off.
        del serialized_local
        del serialized_global
|
|
|
|
def version_number():
    """Returns version number of current stored model.

    This means how many calls to CheckPoint we made so far.

    Returns
    -------
    version : int
        Version number of currently stored model
    """
    # Thin wrapper: the C call already returns the version as an int.
    return _LIB.RabitVersionNumber()
|