From ecb3a271bed151252fb048528ce5a90ad75bb68f Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 29 Feb 2016 10:00:37 -0800
Subject: [PATCH] [PYTHON-DIST] Distributed xgboost python training API.

---
 dmlc-core                          |   2 +-
 include/xgboost/c_api.h            |  26 +++-
 include/xgboost/learner.h          |  15 +++
 python-package/xgboost/__init__.py |   1 +
 python-package/xgboost/compat.py   |   2 +
 python-package/xgboost/core.py     |  57 ++++++++-
 python-package/xgboost/rabit.py    | 189 +++++++++++++++++++++++++++++
 python-package/xgboost/training.py |  65 ++++++----
 rabit                              |   2 +-
 src/c_api/c_api.cc                 |  32 ++++-
 src/data/data.cc                   |   2 +-
 src/learner.cc                     |  30 ++++-
 src/tree/updater_basemaker-inl.h   |   3 +
 src/tree/updater_histmaker.cc      |   1 +
 tests/distributed/runtests.sh      |   4 +
 tests/distributed/test_basic.py    |  29 +++++
 16 files changed, 428 insertions(+), 32 deletions(-)
 create mode 100644 python-package/xgboost/rabit.py
 create mode 100755 tests/distributed/runtests.sh
 create mode 100644 tests/distributed/test_basic.py

diff --git a/dmlc-core b/dmlc-core
index 38ee75d95..71360023d 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 38ee75d95ff23e4e1febacc89e08975d9b6c6c3a
+Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 92a9efe2b..e950a2765 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -358,6 +358,30 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                            bst_ulong *out_len,
                                            const char ***out_models);
 
+/*!
+ * \brief Get a string attribute from the Booster.
+ * \param handle handle
+ * \param key The key of the attribute.
+ * \param out The resulting attribute; can be NULL if the attribute does not exist.
+ * \param success Whether the result is contained in out.
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
+                             const char* key,
+                             const char** out,
+                             int *success);
+/*!
+ * \brief Set a string attribute.
+ *
+ * \param handle handle
+ * \param key The key of the attribute.
+ * \param value The value to be saved.
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
+                             const char* key,
+                             const char* value);
+
 // --- Distributed training API----
 // NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
 /*!
@@ -376,6 +400,6 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(
  * \param handle handle
  * \return 0 when success, -1 when failure happens
  */
-XGB_DLL int XGBoosterSaveRabitCheckPoint(BoosterHandle handle);
+XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
 
 #endif  // XGBOOST_C_API_H_
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index 9c8650312..18c782518 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -109,6 +109,21 @@ class Learner : public rabit::Serializable {
                        std::vector<float> *out_preds,
                        unsigned ntree_limit = 0,
                        bool pred_leaf = false) const = 0;
+  /*!
+   * \brief Set an additional attribute on the Booster.
+   *  The property will be saved along with the booster.
+   * \param key The key of the property.
+   * \param value The value of the property.
+   */
+  virtual void SetAttr(const std::string& key, const std::string& value) = 0;
+  /*!
+   * \brief Get an attribute from the booster.
+   *  The property will be saved along with the booster.
+   * \param key The key of the attribute.
+   * \param out The output value.
+   * \return Whether the key exists in the attributes.
+   */
+  virtual bool GetAttr(const std::string& key, std::string* out) const = 0;
   /*!
    * \return whether the model allow lazy checkpoint in rabit.
    */
diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py
index 1fe438289..304e72355 100644
--- a/python-package/xgboost/__init__.py
+++ b/python-package/xgboost/__init__.py
@@ -10,6 +10,7 @@ import os
 
 from .core import DMatrix, Booster
 from .training import train, cv
+from . import rabit
 try:
     from .sklearn import XGBModel, XGBClassifier, XGBRegressor
     from .plotting import plot_importance, plot_tree, to_graphviz
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index e94c67b84..81234df4a 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -12,9 +12,11 @@ PY3 = (sys.version_info[0] == 3)
 if PY3:
     # pylint: disable=invalid-name, redefined-builtin
     STRING_TYPES = str,
+    py_str = lambda x: x.decode('utf-8')
 else:
     # pylint: disable=invalid-name
     STRING_TYPES = basestring,
+    py_str = lambda x: x
 
 # pandas
 try:
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index c2840bff5..0b6949cf1 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -12,7 +12,7 @@ import scipy.sparse
 
 from .libpath import find_lib_path
 
-from .compat import STRING_TYPES, PY3, DataFrame
+from .compat import STRING_TYPES, PY3, DataFrame, py_str
 
 class XGBoostError(Exception):
     """Error throwed by xgboost trainer."""
@@ -654,10 +654,63 @@ class Booster(object):
 
         Returns
         -------
         booster: `Booster`
-            a copied booster model
+            a copied booster model
         """
         return self.__copy__()
 
+    def load_rabit_checkpoint(self):
+        """Initialize the model by loading it from the rabit checkpoint.
+
+        Returns
+        -------
+        version: integer
+            The version number of the model.
+        """
+        version = ctypes.c_int()
+        _check_call(_LIB.XGBoosterLoadRabitCheckpoint(
+            self.handle, ctypes.byref(version)))
+        return version.value
+
+    def save_rabit_checkpoint(self):
+        """Save the current booster to a rabit checkpoint."""
+        _check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))
+
+    def attr(self, key):
+        """Get an attribute string from the Booster.
+
+        Parameters
+        ----------
+        key : str
+            The key to get the attribute from.
+
+        Returns
+        -------
+        value : str
+            The attribute value of the key; returns None if the attribute does not exist.
+        """
+        ret = ctypes.c_char_p()
+        success = ctypes.c_int()
+        _check_call(_LIB.XGBoosterGetAttr(
+            self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
+        if success.value != 0:
+            return py_str(ret.value)
+        else:
+            return None
+
+    def set_attr(self, **kwargs):
+        """Set attributes of the Booster.
+
+        Parameters
+        ----------
+        **kwargs
+            The attributes to set; values must be strings.
+        """
+        for key, value in kwargs.items():
+            if not isinstance(value, STRING_TYPES):
+                raise ValueError("Set Attr only accepts string values")
+            _check_call(_LIB.XGBoosterSetAttr(
+                self.handle, c_str(key), c_str(str(value))))
+
     def set_param(self, params, value=None):
         """Set parameters into the Booster.
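
A quick sketch of how the new attribute API behaves (illustrative only, not
part of the patch; the input file name is hypothetical). Attribute values
round-trip as strings and, because they are stored in the model, survive
save/load and rabit checkpoints:

    import xgboost as xgb

    dtrain = xgb.DMatrix('train.svm.txt')   # hypothetical data file
    bst = xgb.train({'max_depth': 2}, dtrain, num_boost_round=2)

    bst.set_attr(best_score='0.53')         # values must be strings
    assert bst.attr('best_score') == '0.53'
    assert bst.attr('no_such_key') is None  # absent keys return None
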
diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py
new file mode 100644
index 000000000..da4cfa675
--- /dev/null
+++ b/python-package/xgboost/rabit.py
@@ -0,0 +1,189 @@
+"""Distributed XGBoost Rabit related API."""
+from __future__ import absolute_import
+import sys
+import atexit
+import ctypes
+import pickle
+import numpy as np
+
+from .core import _LIB, c_str, STRING_TYPES
+
+def _init_rabit():
+    """Initialize the rabit library."""
+    _LIB.RabitGetRank.restype = ctypes.c_int
+    _LIB.RabitGetWorldSize.restype = ctypes.c_int
+    _LIB.RabitVersionNumber.restype = ctypes.c_int
+    _LIB.RabitInit(0, None)
+
+
+def finalize():
+    """Finalize the process, notify the tracker everything is done."""
+    _LIB.RabitFinalize()
+
+
+def get_rank():
+    """Get the rank of the current process.
+
+    Returns
+    -------
+    rank : int
+        Rank of the current process.
+    """
+    ret = _LIB.RabitGetRank()
+    return ret
+
+
+def get_world_size():
+    """Get the total number of workers.
+
+    Returns
+    -------
+    n : int
+        Total number of processes.
+    """
+    ret = _LIB.RabitGetWorldSize()
+    return ret
+
+
+def tracker_print(msg):
+    """Print a message to the tracker.
+
+    This function can be used to communicate progress information
+    to the tracker.
+
+    Parameters
+    ----------
+    msg : str
+        The message to be printed to the tracker.
+    """
+    if not isinstance(msg, STRING_TYPES):
+        msg = str(msg)
+    _LIB.RabitTrackerPrint(c_str(msg))
+
+
+def get_processor_name():
+    """Get the processor name.
+
+    Returns
+    -------
+    name : str
+        the name of the processor (host)
+    """
+    mxlen = 256
+    length = ctypes.c_ulong()
+    buf = ctypes.create_string_buffer(mxlen)
+    _LIB.RabitGetProcessorName(buf, ctypes.byref(length), mxlen)
+    return buf.value
+
+
+def broadcast(data, root):
+    """Broadcast an object from one node to all other nodes.
+
+    Parameters
+    ----------
+    data : any type that can be pickled
+        Input data; if the current rank does not equal root, this can be None.
+    root : int
+        Rank of the node to broadcast data from.
+
+    Returns
+    -------
+    object : same type as data
+        The result of the broadcast.
+    """
+    rank = get_rank()
+    length = ctypes.c_ulong()
+    if root == rank:
+        assert data is not None, 'need to pass in data when broadcasting'
+        s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
+        length.value = len(s)
+    # run first broadcast
+    _LIB.RabitBroadcast(ctypes.byref(length),
+                        ctypes.sizeof(ctypes.c_ulong), root)
+    if root != rank:
+        dptr = (ctypes.c_char * length.value)()
+        # run second
+        _LIB.RabitBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
+                            length.value, root)
+        data = pickle.loads(dptr.raw)
+        del dptr
+    else:
+        _LIB.RabitBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
+                            length.value, root)
+        del s
+    return data
+
+# enumeration of dtypes
+DTYPE_ENUM__ = {
+    np.dtype('int8') : 0,
+    np.dtype('uint8') : 1,
+    np.dtype('int32') : 2,
+    np.dtype('uint32') : 3,
+    np.dtype('int64') : 4,
+    np.dtype('uint64') : 5,
+    np.dtype('float32') : 6,
+    np.dtype('float64') : 7
+}
+
+
+def allreduce(data, op, prepare_fun=None):
+    """Perform allreduce, return the result.
+
+    Parameters
+    ----------
+    data: numpy array
+        Input data.
+    op: int
+        Reduction operator; can be MIN, MAX, SUM, BITOR.
+    prepare_fun: function
+        Lazy preprocessing function; if it is not None, prepare_fun(data)
+        will be called by the function before performing allreduce,
+        to initialize the data.
+        If the result of Allreduce can be recovered directly,
+        then prepare_fun will NOT be called.
+
+    Returns
+    -------
+    result : array_like
+        The result of allreduce, with the same shape as data.
+
+    Notes
+    -----
+    This function is not thread-safe.
+    """
+    if not isinstance(data, np.ndarray):
+        raise Exception('allreduce only takes in numpy.ndarray')
+    buf = data.ravel()
+    if buf.base is data.base:
+        buf = buf.copy()
+    if buf.dtype not in DTYPE_ENUM__:
+        raise Exception('data type %s not supported' % str(buf.dtype))
+    if prepare_fun is None:
+        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
+                            buf.size, DTYPE_ENUM__[buf.dtype],
+                            op, None, None)
+    else:
+        func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+        def pfunc(args):
+            """prepare function."""
+            prepare_fun(data)
+        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
+                            buf.size, DTYPE_ENUM__[buf.dtype],
+                            op, func_ptr(pfunc), None)
+    return buf
+
+
+def version_number():
+    """Return the version number of the currently stored model,
+    i.e. the number of calls to CheckPoint made so far.
+
+    Returns
+    -------
+    version : int
+        Version number of the currently stored model.
+    """
+    ret = _LIB.RabitVersionNumber()
+    return ret
+
+# initialization script
+_init_rabit()
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 6e02e8f10..d84c030f9 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -6,9 +6,11 @@ from __future__ import absolute_import
 
 import sys
 import re
+import os
 import numpy as np
 from .core import Booster, STRING_TYPES
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
+from . import rabit
 
 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
           maximize=False, early_stopping_rounds=None, evals_result=None,
@@ -94,6 +96,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         verbose_eval_every_line = verbose_eval
         verbose_eval = True if verbose_eval_every_line > 0 else False
 
+    if rabit.get_rank() != 0:
+        verbose_eval = False
+
     if xgb_model is not None:
         if not isinstance(xgb_model, STRING_TYPES):
             xgb_model = xgb_model.save_raw()
@@ -123,15 +128,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         raise ValueError('For early stopping you need at least one set in evals.')
 
     if verbose_eval:
-        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
+        rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
             evals[-1][1], early_stopping_rounds))
 
     # is params a list of tuples? are we using multiple eval metrics?
     if isinstance(params, list):
         if len(params) != len(dict(params).items()):
             params = dict(params)
-        sys.stderr.write("Multiple eval metrics have been passed: " \
-            "'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
+        rabit.tracker_print("Multiple eval metrics have been passed: " \
+            "'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
     else:
         params = dict(params)
 
@@ -145,23 +150,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     maximize_score = maximize
 
     if maximize_score:
-        best_score = 0.0
+        bst.set_attr(best_score='0.0')
     else:
-        best_score = float('inf')
-
-    best_msg = ''
-    best_score_i = (nboost - 1)
+        bst.set_attr(best_score='inf')
+    bst.set_attr(best_iteration='0')
 
     if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
         raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
 
-    for i in range(nboost, nboost + num_boost_round):
+    # Distributed code: load the checkpoint from rabit.
+    version = bst.load_rabit_checkpoint()
+    assert rabit.get_world_size() != 1 or version == 0
+    start_iteration = int(version / 2)
+    nboost += start_iteration
+
+    for i in range(start_iteration, num_boost_round):
         if learning_rates is not None:
             if isinstance(learning_rates, list):
                 bst.set_param({'eta': learning_rates[i]})
             else:
                 bst.set_param({'eta': learning_rates(i, num_boost_round)})
-        bst.update(dtrain, i, obj)
+
+        # Distributed code: need to resume to this point.
+        # Skip the first update if it is a recovery step.
+        if version % 2 == 0:
+            bst.update(dtrain, i, obj)
+            bst.save_rabit_checkpoint()
+            version += 1
+
+        assert rabit.get_world_size() == 1 or version == rabit.version_number()
 
         nboost += 1
         # check evaluation result.
@@ -176,9 +193,9 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             if verbose_eval:
                 if verbose_eval_every_line:
                     if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
-                        sys.stderr.write(msg + '\n')
+                        rabit.tracker_print(msg + '\n')
                 else:
-                    sys.stderr.write(msg + '\n')
+                    rabit.tracker_print(msg + '\n')
 
             if evals_result is not None:
                 res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
@@ -196,22 +213,26 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
 
             if early_stopping_rounds:
                 score = float(msg.rsplit(':', 1)[1])
+                best_score = float(bst.attr('best_score'))
+                best_iteration = int(bst.attr('best_iteration'))
                 if (maximize_score and score > best_score) or \
                         (not maximize_score and score < best_score):
-                    best_score = score
-                    best_score_i = (nboost - 1)
-                    best_msg = msg
-                elif i - best_score_i >= early_stopping_rounds:
+                    # save the property to attributes, so they will occur in checkpoint.
+                    bst.set_attr(best_score=str(score),
+                                 best_iteration=str(nboost - 1),
+                                 best_msg=msg)
+                elif i - best_iteration >= early_stopping_rounds:
+                    best_msg = bst.attr('best_msg')
                     if verbose_eval:
-                        sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
-                    # best iteration will be assigned in the end.
-                    bst.best_score = best_score
-                    bst.best_iteration = best_score_i
+                        rabit.tracker_print("Stopping. Best iteration:\n{}\n\n".format(best_msg))
                     break
 
+        # do checkpoint after evaluation, in case evaluation also updates the booster.
+        bst.save_rabit_checkpoint()
+        version += 1
 
     if early_stopping_rounds:
-        bst.best_score = best_score
-        bst.best_iteration = best_score_i
+        bst.best_score = float(bst.attr('best_score'))
+        bst.best_iteration = int(bst.attr('best_iteration'))
     else:
         bst.best_iteration = nboost - 1
     bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
diff --git a/rabit b/rabit
index 56ec4263f..1392e9f3d 160000
--- a/rabit
+++ b/rabit
@@ -1 +1 @@
-Subproject commit 56ec4263f9a70a315c1f153dc5897b7c1b58250c
+Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index f9e5375ec..b543f9b6e 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -91,7 +91,7 @@ int XGDMatrixCreateFromFile(const char *fname,
                  << "will split data among workers";
   }
   *out = DMatrix::Load(
-      fname, silent != 0, false);
+      fname, false, true);
   API_END();
 }
@@ -533,18 +533,44 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
   API_END();
 }
 
+int XGBoosterGetAttr(BoosterHandle handle,
+                     const char* key,
+                     const char** out,
+                     int* success) {
+  Booster* bst = static_cast<Booster*>(handle);
+  std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
+  API_BEGIN();
+  if (bst->learner()->GetAttr(key, &ret_str)) {
+    *out = ret_str.c_str();
+    *success = 1;
+  } else {
+    *out = nullptr;
+    *success = 0;
+  }
+  API_END();
+}
+
+int XGBoosterSetAttr(BoosterHandle handle,
+                     const char* key,
+                     const char* value) {
+  Booster* bst = static_cast<Booster*>(handle);
+  API_BEGIN();
+  bst->learner()->SetAttr(key, value);
+  API_END();
+}
+
 int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
                                  int* version) {
   API_BEGIN();
   Booster* bst = static_cast<Booster*>(handle);
   *version = rabit::LoadCheckPoint(bst->learner());
-  if (version != 0) {
+  if (*version != 0) {
     bst->initialized_ = true;
   }
   API_END();
 }
 
-int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
+int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
   API_BEGIN();
   Booster* bst = static_cast<Booster*>(handle);
   if (bst->learner()->AllowLazyCheckPoint()) {
diff --git a/src/data/data.cc b/src/data/data.cc
index 65efa0f8f..e8135692a 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -184,7 +184,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
                  << " of " << npart << " parts";
   }
   // legacy handling of binary data loading
-  if (file_format == "auto" && !load_row_split) {
+  if (file_format == "auto" && npart == 1) {
     int magic;
     std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
     if (fi.get() != nullptr) {
diff --git a/src/learner.cc b/src/learner.cc
index f787be9fc..59f7c53f7 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -5,6 +5,7 @@
  * \author Tianqi Chen
 */
 #include <xgboost/learner.h>
+#include <utility>
 #include <algorithm>
 #include <vector>
 #include <string>
@@ -43,8 +44,10 @@ struct LearnerModelParam
   unsigned num_feature;
   /* \brief number of classes, if it is multi-class classification */
   int num_class;
+  /*! \brief Model contains additional properties */
+  int contain_extra_attrs;
   /*! \brief reserved field */
-  int reserved[31];
+  int reserved[30];
   /*! \brief constructor */
   LearnerModelParam() {
     std::memset(this, 0, sizeof(LearnerModelParam));
   }
@@ -243,6 +246,12 @@ class LearnerImpl : public Learner {
     obj_.reset(ObjFunction::Create(name_obj_));
     gbm_.reset(GradientBooster::Create(name_gbm_));
     gbm_->Load(fi);
+    if (mparam.contain_extra_attrs != 0) {
+      std::vector<std::pair<std::string, std::string> > attr;
+      fi->Read(&attr);
+      attributes_ = std::map<std::string, std::string>(
+          attr.begin(), attr.end());
+    }
     if (metrics_.size() == 0) {
       metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
     }
@@ -259,6 +268,11 @@ class LearnerImpl : public Learner {
     fo->Write(name_obj_);
     fo->Write(name_gbm_);
     gbm_->Save(fo);
+    if (mparam.contain_extra_attrs != 0) {
+      std::vector<std::pair<std::string, std::string> > attr(
+          attributes_.begin(), attributes_.end());
+      fo->Write(attr);
+    }
   }
 
   void UpdateOneIter(int iter, DMatrix* train) override {
@@ -300,6 +314,18 @@ class LearnerImpl : public Learner {
     return os.str();
   }
 
+  void SetAttr(const std::string& key, const std::string& value) override {
+    attributes_[key] = value;
+    mparam.contain_extra_attrs = 1;
+  }
+
+  bool GetAttr(const std::string& key, std::string* out) const override {
+    auto it = attributes_.find(key);
+    if (it == attributes_.end()) return false;
+    *out = it->second;
+    return true;
+  }
+
   std::pair<std::string, float> Evaluate(DMatrix* data, std::string metric) {
     if (metric == "auto") metric = obj_->DefaultEvalMetric();
     std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
@@ -427,6 +453,8 @@ class LearnerImpl : public Learner {
   LearnerTrainParam tparam;
   // configurations
   std::map<std::string, std::string> cfg_;
+  // attributes
+  std::map<std::string, std::string> attributes_;
   // name of gbm
   std::string name_gbm_;
   // name of objective functon
diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h
index b6dbacd6c..2b4b170a6 100644
--- a/src/tree/updater_basemaker-inl.h
+++ b/src/tree/updater_basemaker-inl.h
@@ -56,6 +56,9 @@ class BaseMaker: public TreeUpdater {
         }
       }
     }
+  }
+  /*! \brief synchronize the information */
+  inline void SyncInfo() {
     rabit::Allreduce<rabit::op::Max>(dmlc::BeginPtr(fminmax), fminmax.size());
   }
   // get feature type, 0:empty 1:binary 2:real
diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index bf3f2571e..98c75e935 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -313,6 +313,7 @@ class CQHistMaker: public HistMaker<TStats> {
       feat_helper.InitByCol(p_fmat, tree);
       cache_dmatrix_ = p_fmat;
     }
+    feat_helper.SyncInfo();
     feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
   }
   // code to create histogram
diff --git a/tests/distributed/runtests.sh b/tests/distributed/runtests.sh
new file mode 100755
index 000000000..997fb1893
--- /dev/null
+++ b/tests/distributed/runtests.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+PYTHONPATH=../../python-package/ ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 \
+    python test_basic.py
diff --git a/tests/distributed/test_basic.py b/tests/distributed/test_basic.py
new file mode 100644
index 000000000..20504fd13
--- /dev/null
+++ b/tests/distributed/test_basic.py
@@ -0,0 +1,29 @@
+#!/usr/bin/python
+import numpy as np
+import scipy.sparse
+import pickle
+import xgboost as xgb
+
+# Load the files; they will be automatically sharded in distributed mode.
+dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+# Specify parameters via a map; definitions are the same as the C++ version.
+param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
+
+# Specify a validation set to watch performance.
+watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+num_round = 20
+
+# Run training; all the features of the training API are available.
+# Currently, this script only supports calling train once, for fault-recovery purposes.
+bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
+
+# Save the model; only ask process 0 to save the model.
+if xgb.rabit.get_rank() == 0:
+    bst.save_model("test.model")
+    xgb.rabit.tracker_print("Finished training\n")
+
+# Notify the tracker that all training has been successful.
+# This is only needed in distributed training.
+xgb.rabit.finalize()
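
For context, a minimal end-to-end sketch of the communication primitives this
patch exposes (illustrative only; the file name demo_comm.py is made up, and
only APIs added by this patch are used). The launch command mirrors
runtests.sh:

    # demo_comm.py -- launch with:
    #   PYTHONPATH=../../python-package/ ../../dmlc-core/tracker/dmlc-submit \
    #       --cluster=local --num-workers=3 python demo_comm.py
    import xgboost as xgb

    rank = xgb.rabit.get_rank()
    world = xgb.rabit.get_world_size()

    # Rank 0 pickles and broadcasts the payload; the other ranks pass None and
    # receive the unpickled copy (rabit.broadcast issues two RabitBroadcast
    # calls underneath: one for the length, one for the pickled bytes).
    payload = {'eta': 0.3, 'round': 10} if rank == 0 else None
    payload = xgb.rabit.broadcast(payload, 0)

    xgb.rabit.tracker_print('worker %d/%d got %s\n' % (rank, world, payload))
    xgb.rabit.finalize()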