[Breaking] Set output margin to True for custom objective. (#5564)

* Set output margin to True for custom objective in Python and R.

* Add a demo for writing multi-class custom objective function.

* Run tests on selected demos.
This commit is contained in:
Jiaming Yuan
2020-04-20 20:44:12 +08:00
committed by GitHub
parent fcbedcedf8
commit 9c1103e06c
10 changed files with 262 additions and 22 deletions

View File

@@ -15,6 +15,7 @@ import numpy as np
import xgboost as xgb
from typing import Tuple, Dict, List
from time import time
import argparse
import matplotlib
from matplotlib import pyplot as plt
@@ -150,12 +151,7 @@ def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
return results
if __name__ == '__main__':
dtrain, dtest = generate_data()
rmse_evals = native_rmse(dtrain, dtest)
rmsle_evals = native_rmsle(dtrain, dtest)
py_rmsle_evals = py_rmsle(dtrain, dtest)
def plot_history(rmse_evals, rmsle_evals, py_rmsle_evals):
fig, axs = plt.subplots(3, 1)
ax0: matplotlib.axes.Axes = axs[0]
ax1: matplotlib.axes.Axes = axs[1]
@@ -177,3 +173,25 @@ if __name__ == '__main__':
plt.show()
plt.close()
def main(args):
dtrain, dtest = generate_data()
rmse_evals = native_rmse(dtrain, dtest)
rmsle_evals = native_rmsle(dtrain, dtest)
py_rmsle_evals = py_rmsle(dtrain, dtest)
if args.plot != 0:
plot_history(rmse_evals, rmsle_evals, py_rmsle_evals)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Arguments for custom RMSLE objective function demo.')
parser.add_argument(
'--plot',
type=int,
default=1,
help='Set to 0 to disable plotting the evaluation history.')
args = parser.parse_args()
main(args)