[PYTHON] Simplify training logic, update rabit lib
commit 4a16b729fc
parent 90bc7f8f6b

Makefile
@@ -118,7 +118,7 @@ lib/libxgboost.a: $(ALL_DEP)
 lib/libxgboost.dll lib/libxgboost.so: $(ALL_DEP)
 	@mkdir -p $(@D)
-	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
+	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %a, $^) $(LDFLAGS)

 java/libxgboost4j.so: java/xgboost4j_wrapper.cpp $(ALL_DEP)
 	$(CXX) $(CFLAGS) $(JAVAINCFLAGS) -shared -o $@ $(filter %.cpp %.o %.a, $^) $(LDFLAGS)

@@ -14,4 +14,5 @@ PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/dmlc-core/include -I$(PKGROOT)/
 PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS)
 PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS)
 OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o\
-	$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o $(PKGROOT)/rabit/src/engine_empty.o
+	$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o\
+	$(PKGROOT)/rabit/src/engine_empty.o $(PKGROOT)/rabit/src/c_api.o
@@ -26,6 +26,7 @@ PKG_CPPFLAGS= -I$(PKGROOT)/include -I$(PKGROOT)/dmlc-core/include -I$(PKGROOT)/
 PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS)
 PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS)
 OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o\
-	$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o $(PKGROOT)/rabit/src/engine_empty.o
+	$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o\
+	$(PKGROOT)/rabit/src/engine_empty.o $(PKGROOT)/rabit/src/c_api.o

 $(OBJECTS) : xgblib
include/xgboost/c_api.h
@@ -11,6 +11,9 @@
 #define XGB_EXTERN_C extern "C"
 #endif

+// XGBoost C API will include APIs in Rabit C API
+#include <rabit/c_api.h>
+
 #if defined(_MSC_VER) || defined(_WIN32)
 #define XGB_DLL XGB_EXTERN_C __declspec(dllexport)
 #else
@@ -221,6 +224,7 @@ XGB_DLL int XGBoosterFree(BoosterHandle handle);
 XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
                               const char *name,
                               const char *value);
+
 /*!
  * \brief update the model in one round using dtrain
  * \param handle handle
@@ -282,6 +286,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                              unsigned ntree_limit,
                              bst_ulong *out_len,
                              const float **out_result);
+
 /*!
  * \brief load model from existing file
  * \param handle handle
@@ -353,4 +358,24 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                            bst_ulong *out_len,
                                            const char ***out_models);

+// --- Distributed training API----
+// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
+/*!
+ * \brief Initialize the booster from rabit checkpoint.
+ *  This is used in distributed training API.
+ * \param handle handle
+ * \param version The output version of the model.
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterLoadRabitCheckpoint(
+    BoosterHandle handle,
+    int* version);
+
+/*!
+ * \brief Save the current checkpoint to rabit.
+ * \param handle handle
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterSaveRabitCheckPoint(BoosterHandle handle);
+
 #endif // XGBOOST_C_API_H_
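The two declarations above are exported from libxgboost.so, so they can be exercised directly from Python. A minimal ctypes sketch — the library path and the _check helper are illustrative assumptions, not part of this commit:

import ctypes

# Assumed build output path; adjust to wherever libxgboost.so actually lives.
lib = ctypes.cdll.LoadLibrary("lib/libxgboost.so")

# BoosterHandle is an opaque pointer, so c_void_p matches the header.
lib.XGBoosterLoadRabitCheckpoint.restype = ctypes.c_int
lib.XGBoosterLoadRabitCheckpoint.argtypes = [ctypes.c_void_p,
                                             ctypes.POINTER(ctypes.c_int)]
lib.XGBoosterSaveRabitCheckPoint.restype = ctypes.c_int
lib.XGBoosterSaveRabitCheckPoint.argtypes = [ctypes.c_void_p]

def _check(rc):
    # Both entry points return 0 on success, -1 on failure, per the doc comments.
    assert rc == 0, "XGBoost C API call failed"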
include/xgboost/learner.h
@@ -8,7 +8,7 @@
 #ifndef XGBOOST_LEARNER_H_
 #define XGBOOST_LEARNER_H_

-#include <rabit.h>
+#include <rabit/rabit.h>
 #include <utility>
 #include <string>
 #include <vector>
python-package/xgboost/training.py
@@ -24,7 +24,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Data to be trained.
     num_boost_round: int
         Number of boosting iterations.
-    watchlist (evals): list of pairs (DMatrix, string)
+    evals: list of pairs (DMatrix, string)
         List of items to be evaluated during training, this allows user to watch
         performance on the validation set.
     obj : function
@@ -117,48 +117,13 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         evals_result.clear()
         evals_result.update(dict([(key, {}) for key in evals_name]))

-    if not early_stopping_rounds:
-        for i in range(nboost, nboost + num_boost_round):
-            bst.update(dtrain, i, obj)
-            nboost += 1
-            if len(evals) != 0:
-                bst_eval_set = bst.eval_set(evals, i, feval)
-                if isinstance(bst_eval_set, STRING_TYPES):
-                    msg = bst_eval_set
-                else:
-                    msg = bst_eval_set.decode()
-
-                if verbose_eval:
-                    if verbose_eval_every_line:
-                        if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
-                            sys.stderr.write(msg + '\n')
-                    else:
-                        sys.stderr.write(msg + '\n')
-
-                if evals_result is not None:
-                    res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
-                    for key in evals_name:
-                        evals_idx = evals_name.index(key)
-                        res_per_eval = len(res) // len(evals_name)
-                        for r in range(res_per_eval):
-                            res_item = res[(evals_idx*res_per_eval) + r]
-                            res_key = res_item[0]
-                            res_val = res_item[1]
-                            if res_key in evals_result[key]:
-                                evals_result[key][res_key].append(res_val)
-                            else:
-                                evals_result[key][res_key] = [res_val]
-        bst.best_iteration = (nboost - 1)
-        bst.best_ntree_limit = nboost * num_parallel_tree
-        return bst
-
-    else:
-        # early stopping
+    # early stopping
+    if early_stopping_rounds is not None:
         if len(evals) < 1:
             raise ValueError('For early stopping you need at least one set in evals.')

         if verbose_eval:
-            sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(\
+            sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
                 evals[-1][1], early_stopping_rounds))

     # is params a list of tuples? are we using multiple eval metrics?
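The net effect of the hunk above: the separate no-early-stopping loop is gone, and early stopping becomes setup code before a single shared training loop. A hypothetical call exercising the signature in the hunk header (params, dtrain, and dvalid stand in for a real parameter dict and DMatrix objects):

import xgboost as xgb

evals_result = {}  # populated in place with per-round metrics, one dict per eval set
bst = xgb.train(params, dtrain,
                num_boost_round=50,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                evals_result=evals_result,
                verbose_eval=True)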
@@ -166,7 +131,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         if len(params) != len(dict(params).items()):
             params = dict(params)
             sys.stderr.write("Multiple eval metrics have been passed: " \
-                             "'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
+                "'{0}' will be used for early stopping.\n\n".format(params['eval_metric']))
         else:
             params = dict(params)

@@ -184,20 +149,23 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         else:
             best_score = float('inf')

-        best_msg = ''
-        best_score_i = (nboost - 1)
+    best_msg = ''
+    best_score_i = (nboost - 1)

-        if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
-            raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
+    if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
+        raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
+
-        for i in range(nboost, nboost + num_boost_round):
-            if learning_rates is not None:
-                if isinstance(learning_rates, list):
-                    bst.set_param({'eta': learning_rates[i]})
-                else:
-                    bst.set_param({'eta': learning_rates(i, num_boost_round)})
-            bst.update(dtrain, i, obj)
-            nboost += 1
+    for i in range(nboost, nboost + num_boost_round):
+        if learning_rates is not None:
+            if isinstance(learning_rates, list):
+                bst.set_param({'eta': learning_rates[i]})
+            else:
+                bst.set_param({'eta': learning_rates(i, num_boost_round)})
+        bst.update(dtrain, i, obj)
+
+        nboost += 1
+        # check evaluation result.
         if len(evals) != 0:
             bst_eval_set = bst.eval_set(evals, i, feval)

             if isinstance(bst_eval_set, STRING_TYPES):
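The dedented loop keeps both forms of learning_rates visible above: a list is indexed by the round counter, and a callable is invoked as learning_rates(i, num_boost_round). A sketch with assumed params/dtrain:

# One eta per round; the length check above enforces len == num_boost_round.
bst = xgb.train(params, dtrain, num_boost_round=3,
                learning_rates=[0.3, 0.1, 0.05])

# Or a decay function of (round_index, total_rounds).
bst = xgb.train(params, dtrain, num_boost_round=10,
                learning_rates=lambda i, n: 0.3 * (0.99 ** i))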
@@ -226,22 +194,28 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                         else:
                             evals_result[key][res_key] = [res_val]

-                score = float(msg.rsplit(':', 1)[1])
-                if (maximize_score and score > best_score) or \
-                        (not maximize_score and score < best_score):
-                    best_score = score
-                    best_score_i = (nboost - 1)
-                    best_msg = msg
-                elif i - best_score_i >= early_stopping_rounds:
-                    if verbose_eval:
-                        sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
-                    bst.best_score = best_score
-                    bst.best_iteration = best_score_i
-                    break
-        bst.best_score = best_score
-        bst.best_iteration = best_score_i
-        bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
-        return bst
+            if early_stopping_rounds:
+                score = float(msg.rsplit(':', 1)[1])
+                if (maximize_score and score > best_score) or \
+                        (not maximize_score and score < best_score):
+                    best_score = score
+                    best_score_i = (nboost - 1)
+                    best_msg = msg
+                elif i - best_score_i >= early_stopping_rounds:
+                    if verbose_eval:
+                        sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
+                    # best iteration will be assigned in the end.
+                    bst.best_score = best_score
+                    bst.best_iteration = best_score_i
+                    break
+
+    if early_stopping_rounds:
+        bst.best_score = best_score
+        bst.best_iteration = best_score_i
+    else:
+        bst.best_iteration = nboost - 1
+    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+    return bst


 class CVPack(object):
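With the assignment moved after the loop, best_iteration and best_ntree_limit are set whether or not early stopping triggered, so callers can rely on them unconditionally. A sketch of the intended usage (dvalid/dtest are assumed DMatrix objects):

bst = xgb.train(params, dtrain, num_boost_round=1000,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=10)

# The last entry in evals ('valid' here) drives the stopping decision.
print(bst.best_score, bst.best_iteration, bst.best_ntree_limit)
preds = bst.predict(dtest, ntree_limit=bst.best_ntree_limit)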
@@ -299,7 +273,7 @@ def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
     # pylint: disable=invalid-name
     """
     Aggregate cross-validation results.
-
+
     If verbose_eval is true, progress is displayed in every call. If
     verbose_eval is an integer, progress will only be displayed every
     `verbose_eval` trees, tracked via trial.
@@ -486,4 +460,3 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     results = np.array(results)

     return results
-
rabit
@@ -1 +1 @@
-Subproject commit 112d866dc92354304c0891500374fe40cdf13a50
+Subproject commit 56ec4263f9a70a315c1f153dc5897b7c1b58250c
@@ -3,6 +3,8 @@
 #include <xgboost/data.h>
 #include <xgboost/learner.h>
+#include <xgboost/c_api.h>
+#include <xgboost/logging.h>
 #include <rabit/rabit.h>
 #include <cstdio>
 #include <vector>
 #include <string>
@@ -84,6 +86,10 @@ int XGDMatrixCreateFromFile(const char *fname,
                             int silent,
                             DMatrixHandle *out) {
   API_BEGIN();
+  if (rabit::IsDistributed()) {
+    LOG(CONSOLE) << "XGBoost distributed mode detected, "
+                 << "will split data among workers";
+  }
   *out = DMatrix::Load(
       fname, silent != 0, false);
   API_END();
@@ -526,3 +532,28 @@ int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
   XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
   API_END();
 }
+
+int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
+                                 int* version) {
+  API_BEGIN();
+  Booster* bst = static_cast<Booster*>(handle);
+  *version = rabit::LoadCheckPoint(bst->learner());
+  if (*version != 0) {
+    bst->initialized_ = true;
+  }
+  API_END();
+}
+
+int XGBoosterSaveRabitCheckPoint(BoosterHandle handle) {
+  API_BEGIN();
+  Booster* bst = static_cast<Booster*>(handle);
+  if (bst->learner()->AllowLazyCheckPoint()) {
+    rabit::LazyCheckPoint(bst->learner());
+  } else {
+    rabit::CheckPoint(bst->learner());
+  }
+  API_END();
+}
+
+// force link rabit
+static int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
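Together these give a distributed driver a resume point: XGBoosterLoadRabitCheckpoint reports the recovered version (0 on a fresh start) and XGBoosterSaveRabitCheckPoint records progress. A minimal ctypes sketch of such a loop, reusing the bindings sketched earlier; booster_handle, dtrain_handle, and the one-checkpoint-per-round resume arithmetic are assumptions, not something this commit prescribes:

import ctypes

version = ctypes.c_int(0)
_check(lib.XGBoosterLoadRabitCheckpoint(booster_handle, ctypes.byref(version)))

for i in range(version.value, num_round):
    # XGBoosterUpdateOneIter is the existing one-round update entry point.
    _check(lib.XGBoosterUpdateOneIter(booster_handle, i, dtrain_handle))
    _check(lib.XGBoosterSaveRabitCheckPoint(booster_handle))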
src/common/sync.h
@@ -8,6 +8,6 @@
 #ifndef XGBOOST_COMMON_SYNC_H_
 #define XGBOOST_COMMON_SYNC_H_

-#include <rabit.h>
+#include <rabit/rabit.h>

 #endif  // XGBOOST_COMMON_SYNC_H_
@@ -31,7 +31,7 @@ class TestModels(unittest.TestCase):

         # learning_rates as a customized decay function
         def eta_decay(ithround, num_boost_round):
-            return num_boost_round / ithround
+            return num_boost_round / (ithround + 1)
         bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=eta_decay)
         assert isinstance(bst, xgb.core.Booster)
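The test fix is needed because train() now calls the schedule with a zero-based round index, so the old body divided by zero on the first round. A self-contained check:

num_round = 4

def eta_decay(ithround, num_boost_round):
    # ithround starts at 0; the +1 keeps the denominator nonzero.
    return num_boost_round / (ithround + 1)

print([eta_decay(i, num_round) for i in range(num_round)])
# [4.0, 2.0, 1.33..., 1.0] under Python 3; the old body raised ZeroDivisionError at i == 0.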
@@ -49,7 +49,7 @@ class TestModels(unittest.TestCase):
         def evalerror(preds, dtrain):
             labels = dtrain.get_label()
             return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
-
+
         # test custom_objective in training
         bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)
         assert isinstance(bst, xgb.core.Booster)