[Breaking] Fix custom metric for multi output. (#5954)

* Set output margin to true when predicting for a custom metric, so that the metric receives untransformed predictions. This fixes only the R and Python interfaces.
This commit is contained in:
Jiaming Yuan 2020-07-29 19:25:27 +08:00 committed by GitHub
parent 75b8c22b0b
commit 18349a7ccf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 41 additions and 13 deletions

View File

@ -173,7 +173,7 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
} else { } else {
res <- sapply(seq_along(watchlist), function(j) { res <- sapply(seq_along(watchlist), function(j) {
w <- watchlist[[j]] w <- watchlist[[j]]
preds <- predict(booster_handle, w, ntreelimit = 0) # predict using all trees preds <- predict(booster_handle, w, outputmargin = TRUE, ntreelimit = 0) # predict using all trees
eval_res <- feval(preds, w) eval_res <- feval(preds, w)
out <- eval_res$value out <- eval_res$value
names(out) <- paste0(evnames[j], "-", eval_res$metric) names(out) <- paste0(evnames[j], "-", eval_res$metric)

View File

@ -79,6 +79,10 @@ test_that("custom objective with multi-class works", {
hess <- rnorm(dim(as.matrix(preds))[1]) hess <- rnorm(dim(as.matrix(preds))[1])
return (list(grad = grad, hess = hess)) return (list(grad = grad, hess = hess))
} }
fake_merror <- function(preds, dtrain) {
  # Custom metric used only to check the shape of the predictions handed to it:
  # there must be one value per (row, class) pair.
  # `data` and `nclasses` come from the enclosing test scope (defined outside
  # this view).
  expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
  # The evaluation loop reads `$metric` and `$value` from a custom metric's
  # result, so return a well-formed (dummy) entry instead of whatever
  # expect_equal() yields.
  list(metric = "fake_merror", value = 0)
}
param$objective <- fake_softprob param$objective <- fake_softprob
param$eval_metric <- fake_merror
bst <- xgb.train(param, dtrain, 1, num_class = nclasses) bst <- xgb.train(param, dtrain, 1, num_class = nclasses)
}) })

View File

@ -75,7 +75,7 @@ def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
return grad, hess return grad, hess
def predict(booster, X): def predict(booster: xgb.Booster, X):
'''A customized prediction function that converts raw prediction to '''A customized prediction function that converts raw prediction to
target class. target class.
@ -93,15 +93,34 @@ def predict(booster, X):
return out return out
def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
    '''Custom multi-class error metric.

    Like a custom objective, the metric receives ``predt`` as untransformed
    leaf weights, one row per sample and one column per class.  The class
    with the largest raw score is taken as the prediction for each row.

    ``kRows`` and ``kClasses`` are module-level constants defined outside
    this block.

    Returns a ``(name, value)`` pair: the metric name ``'PyMError'`` and the
    fraction of rows whose predicted class differs from the label.
    '''
    labels = dtrain.get_label()
    # Like custom objective, the predt is untransformed leaf weight
    assert predt.shape == (kRows, kClasses)

    # Index of the highest raw score in every row, kept as floats so the
    # comparison against the labels matches a float-filled buffer.
    predicted = np.argmax(predt, axis=1).astype(np.float64)
    assert labels.shape == predicted.shape

    # 1.0 for every misclassified row, 0.0 otherwise.
    wrong = np.where(labels != predicted, 1.0, 0.0)
    return 'PyMError', np.sum(wrong) / kRows
def plot_history(custom_results, native_results): def plot_history(custom_results, native_results):
fig, axs = plt.subplots(2, 1) fig, axs = plt.subplots(2, 1)
ax0 = axs[0] ax0 = axs[0]
ax1 = axs[1] ax1 = axs[1]
pymerror = custom_results['train']['PyMError']
merror = native_results['train']['merror']
x = np.arange(0, kRounds, 1) x = np.arange(0, kRounds, 1)
ax0.plot(x, custom_results['train']['merror'], label='Custom objective') ax0.plot(x, pymerror, label='Custom objective')
ax0.legend() ax0.legend()
ax1.plot(x, native_results['train']['merror'], label='multi:softmax') ax1.plot(x, merror, label='multi:softmax')
ax1.legend() ax1.legend()
plt.show() plt.show()
@ -110,10 +129,12 @@ def plot_history(custom_results, native_results):
def main(args): def main(args):
custom_results = {} custom_results = {}
# Use our custom objective function # Use our custom objective function
booster_custom = xgb.train({'num_class': kClasses}, booster_custom = xgb.train({'num_class': kClasses,
'disable_default_eval_metric': True},
m, m,
num_boost_round=kRounds, num_boost_round=kRounds,
obj=softprob_obj, obj=softprob_obj,
feval=merror,
evals_result=custom_results, evals_result=custom_results,
evals=[(m, 'train')]) evals=[(m, 'train')])
@ -131,6 +152,8 @@ def main(args):
# We are reimplementing the loss function in XGBoost, so it should # We are reimplementing the loss function in XGBoost, so it should
# be the same for normal cases. # be the same for normal cases.
assert np.all(predt_custom == predt_native) assert np.all(predt_custom == predt_native)
np.testing.assert_allclose(custom_results['train']['PyMError'],
native_results['train']['merror'])
if args.plot != 0: if args.plot != 0:
plot_history(custom_results, native_results) plot_history(custom_results, native_results)

View File

@ -40,9 +40,9 @@ General Parameters
- Number of parallel threads used to run XGBoost - Number of parallel threads used to run XGBoost
* ``disable_default_eval_metric`` [default=0] * ``disable_default_eval_metric`` [default=``false``]
- Flag to disable default metric. Set to >0 to disable. - Flag to disable default metric. Set to 1 or ``true`` to disable.
* ``num_pbuffer`` [set automatically by XGBoost, no need to be set by user] * ``num_pbuffer`` [set automatically by XGBoost, no need to be set by user]

View File

@ -1228,7 +1228,8 @@ class Booster(object):
res = msg.value.decode() res = msg.value.decode()
if feval is not None: if feval is not None:
for dmat, evname in evals: for dmat, evname in evals:
feval_ret = feval(self.predict(dmat, training=False), dmat) feval_ret = feval(self.predict(dmat, training=False,
output_margin=True), dmat)
if isinstance(feval_ret, list): if isinstance(feval_ret, list):
for name, val in feval_ret: for name, val in feval_ret:
res += '\t%s-%s:%f' % (evname, name, val) res += '\t%s-%s:%f' % (evname, name, val)

View File

@ -154,9 +154,9 @@ LearnerModelParam::LearnerModelParam(
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> { struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none. // data split mode, can be row, col, or none.
DataSplitMode dsplit; DataSplitMode dsplit {DataSplitMode::kAuto};
// flag to disable default metric // flag to disable default metric
int disable_default_eval_metric; bool disable_default_eval_metric {false};
// FIXME(trivialfis): The following parameters belong to model itself, but can be // FIXME(trivialfis): The following parameters belong to model itself, but can be
// specified by users. Move them to model parameter once we can get rid of binary IO. // specified by users. Move them to model parameter once we can get rid of binary IO.
std::string booster; std::string booster;
@ -171,7 +171,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.add_enum("row", DataSplitMode::kRow) .add_enum("row", DataSplitMode::kRow)
.describe("Data split mode for distributed training."); .describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric) DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(0) .set_default(false)
.describe("Flag to disable default metric. Set to >0 to disable"); .describe("Flag to disable default metric. Set to >0 to disable");
DMLC_DECLARE_FIELD(booster) DMLC_DECLARE_FIELD(booster)
.set_default("gbtree") .set_default("gbtree")
@ -253,7 +253,7 @@ class LearnerConfiguration : public Learner {
void Configure() override { void Configure() override {
// Varient of double checked lock // Varient of double checked lock
if (!this->need_configuration_) { return; } if (!this->need_configuration_) { return; }
std::lock_guard<std::mutex> gard(config_lock_); std::lock_guard<std::mutex> guard(config_lock_);
if (!this->need_configuration_) { return; } if (!this->need_configuration_) { return; }
monitor_.Start("Configure"); monitor_.Start("Configure");

View File

@ -37,11 +37,11 @@ class TestEarlyStopping(unittest.TestCase):
eval_set=[(X_test, y_test)]) eval_set=[(X_test, y_test)])
assert clf3.best_score == 1 assert clf3.best_score == 1
@pytest.mark.skipif(**tm.no_sklearn())
def evalerror(self, preds, dtrain): def evalerror(self, preds, dtrain):
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error
labels = dtrain.get_label() labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
return 'rmse', mean_squared_error(labels, preds) return 'rmse', mean_squared_error(labels, preds)
@staticmethod @staticmethod