[Breaking] Fix custom metric for multi output. (#5954)
* Set output margin to true for custom metric. This fix applies only to the R and Python packages.
parent 75b8c22b0b
commit 18349a7ccf
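For context, here is a minimal Python sketch (not part of the diff; the names `logloss` and the toy data are illustrative) of what the breaking change means for user code: `feval` now receives untransformed margin scores instead of transformed probabilities, so a custom metric must apply the link function itself. The hunks below touch the R evaluation helper (`xgb.iter.eval`), the R custom-objective test, the Python multiclass demo, the parameter docs, the Python `Booster` evaluation path, the C++ `LearnerTrainParam`, and the early-stopping test.

import numpy as np
import xgboost as xgb

def logloss(predt: np.ndarray, dtrain: xgb.DMatrix):
    # After this change predt holds raw margins, so we apply the sigmoid
    # link ourselves before computing the binary log loss.
    y = dtrain.get_label()
    prob = np.clip(1.0 / (1.0 + np.exp(-predt)), 1e-15, 1 - 1e-15)
    return 'PyLogLoss', float(-np.mean(y * np.log(prob) + (1 - y) * np.log(1 - prob)))

rng = np.random.RandomState(0)
X = rng.randn(64, 4)
y = (rng.rand(64) > 0.5).astype(np.float32)
m = xgb.DMatrix(X, label=y)
xgb.train({'objective': 'binary:logistic',
           'disable_default_eval_metric': True},
          m, num_boost_round=4, feval=logloss, evals=[(m, 'train')])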
@@ -173,7 +173,7 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
   } else {
     res <- sapply(seq_along(watchlist), function(j) {
       w <- watchlist[[j]]
-      preds <- predict(booster_handle, w, ntreelimit = 0) # predict using all trees
+      preds <- predict(booster_handle, w, outputmargin = TRUE, ntreelimit = 0) # predict using all trees
       eval_res <- feval(preds, w)
       out <- eval_res$value
       names(out) <- paste0(evnames[j], "-", eval_res$metric)
@@ -79,6 +79,10 @@ test_that("custom objective with multi-class works", {
    hess <- rnorm(dim(as.matrix(preds))[1])
    return (list(grad = grad, hess = hess))
  }
+ fake_merror <- function(preds, dtrain) {
+   expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
+ }
  param$objective <- fake_softprob
+ param$eval_metric <- fake_merror
  bst <- xgb.train(param, dtrain, 1, num_class = nclasses)
})
@@ -75,7 +75,7 @@ def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
     return grad, hess


-def predict(booster, X):
+def predict(booster: xgb.Booster, X):
     '''A customized prediction function that converts raw prediction to
     target class.

@@ -93,15 +93,34 @@ def predict(booster, X):
     return out


+def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
+    y = dtrain.get_label()
+    # Like custom objective, the predt is untransformed leaf weight
+    assert predt.shape == (kRows, kClasses)
+    out = np.zeros(kRows)
+    for r in range(predt.shape[0]):
+        i = np.argmax(predt[r])
+        out[r] = i
+
+    assert y.shape == out.shape
+
+    errors = np.zeros(kRows)
+    errors[y != out] = 1.0
+    return 'PyMError', np.sum(errors) / kRows
+
+
 def plot_history(custom_results, native_results):
     fig, axs = plt.subplots(2, 1)
     ax0 = axs[0]
     ax1 = axs[1]

+    pymerror = custom_results['train']['PyMError']
+    merror = native_results['train']['merror']
+
     x = np.arange(0, kRounds, 1)
-    ax0.plot(x, custom_results['train']['merror'], label='Custom objective')
+    ax0.plot(x, pymerror, label='Custom objective')
     ax0.legend()
-    ax1.plot(x, native_results['train']['merror'], label='multi:softmax')
+    ax1.plot(x, merror, label='multi:softmax')
     ax1.legend()

     plt.show()
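A side note on why the new `merror` above can work directly on raw margins (a sketch; the `softmax` helper is illustrative, not part of the demo): argmax is invariant under the softmax transform, so the predicted class can be read off the untransformed scores.

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

margins = np.array([1.0, -0.5, 2.0])
# The winning class is the same before and after the transform.
assert np.argmax(margins) == np.argmax(softmax(margins))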
@@ -110,10 +129,12 @@ def plot_history(custom_results, native_results):
 def main(args):
     custom_results = {}
     # Use our custom objective function
-    booster_custom = xgb.train({'num_class': kClasses},
+    booster_custom = xgb.train({'num_class': kClasses,
+                                'disable_default_eval_metric': True},
                                m,
                                num_boost_round=kRounds,
                                obj=softprob_obj,
+                               feval=merror,
                                evals_result=custom_results,
                                evals=[(m, 'train')])

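The demo disables the built-in metric and registers the Python merror through `feval`, so `custom_results` records only the `PyMError` series that the assertion in the next hunk compares against the native run.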
@@ -131,6 +152,8 @@ def main(args):
     # We are reimplementing the loss function in XGBoost, so it should
     # be the same for normal cases.
     assert np.all(predt_custom == predt_native)
+    np.testing.assert_allclose(custom_results['train']['PyMError'],
+                               native_results['train']['merror'])

     if args.plot != 0:
         plot_history(custom_results, native_results)
@@ -40,9 +40,9 @@ General Parameters

   - Number of parallel threads used to run XGBoost

-* ``disable_default_eval_metric`` [default=0]
+* ``disable_default_eval_metric`` [default=``false``]

-  - Flag to disable default metric. Set to >0 to disable.
+  - Flag to disable default metric. Set to 1 or ``true`` to disable.

 * ``num_pbuffer`` [set automatically by XGBoost, no need to be set by user]

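A short usage sketch (illustrative, consistent with the doc change above): both the legacy integer spelling and the new boolean spelling should configure the same flag.

params_legacy = {'disable_default_eval_metric': 1}     # old integer form
params_bool = {'disable_default_eval_metric': True}    # new boolean form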
@@ -1228,7 +1228,8 @@ class Booster(object):
         res = msg.value.decode()
         if feval is not None:
             for dmat, evname in evals:
-                feval_ret = feval(self.predict(dmat, training=False), dmat)
+                feval_ret = feval(self.predict(dmat, training=False,
+                                               output_margin=True), dmat)
                 if isinstance(feval_ret, list):
                     for name, val in feval_ret:
                         res += '\t%s-%s:%f' % (evname, name, val)
@@ -154,9 +154,9 @@ LearnerModelParam::LearnerModelParam(

 struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
   // data split mode, can be row, col, or none.
-  DataSplitMode dsplit;
+  DataSplitMode dsplit {DataSplitMode::kAuto};
   // flag to disable default metric
-  int disable_default_eval_metric;
+  bool disable_default_eval_metric {false};
   // FIXME(trivialfis): The following parameters belong to model itself, but can be
   // specified by users. Move them to model parameter once we can get rid of binary IO.
   std::string booster;
@@ -171,7 +171,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
         .add_enum("row", DataSplitMode::kRow)
         .describe("Data split mode for distributed training.");
     DMLC_DECLARE_FIELD(disable_default_eval_metric)
-        .set_default(0)
+        .set_default(false)
         .describe("Flag to disable default metric. Set to >0 to disable");
     DMLC_DECLARE_FIELD(booster)
         .set_default("gbtree")
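Changing the field from `int` to `bool` (with an in-class default of `false`) is what lets the documentation above advertise both `1` and `true`: the parameter layer should parse either spelling into the boolean flag. Note that the `.describe()` string still reads "Set to >0 to disable" and was not updated to match the new type.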
@@ -253,7 +253,7 @@ class LearnerConfiguration : public Learner {
   void Configure() override {
     // Varient of double checked lock
     if (!this->need_configuration_) { return; }
-    std::lock_guard<std::mutex> gard(config_lock_);
+    std::lock_guard<std::mutex> guard(config_lock_);
     if (!this->need_configuration_) { return; }

     monitor_.Start("Configure");
@@ -37,11 +37,11 @@ class TestEarlyStopping(unittest.TestCase):
                 eval_set=[(X_test, y_test)])
         assert clf3.best_score == 1

-    @pytest.mark.skipif(**tm.no_sklearn())
     def evalerror(self, preds, dtrain):
         from sklearn.metrics import mean_squared_error

         labels = dtrain.get_label()
+        preds = 1.0 / (1.0 + np.exp(-preds))
         return 'rmse', mean_squared_error(labels, preds)

     @staticmethod
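The added sigmoid line mirrors the core change: the `preds` passed to a custom metric are now raw margins, so the test converts them back to probabilities before scoring. (The metric still reports under the name 'rmse' while `mean_squared_error` returns the squared error; that is a quirk of the original test rather than part of this change.)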