Prevent training without setting up caches. (#4066)

* Prevent training without setting up caches.

* Add warning for internal functions.
* Check number of features.

* Address reviewer's comment.
This commit is contained in:
Jiaming Yuan 2019-02-03 17:03:29 +08:00 committed by Philip Hyunsu Cho
parent 7a652a8c64
commit 1088dff42c
3 changed files with 13 additions and 5 deletions

View File

@ -161,6 +161,10 @@ A saved model can be loaded as follows:
bst = xgb.Booster({'nthread': 4}) # init model
bst.load_model('model.bin') # load data
Methods including `update` and `boost` from `xgboost.Booster` are designed for
internal usage only. The wrapper function `xgboost.train` does some
pre-configuration including setting up caches and some other parameters.
Early Stopping
--------------
If you have a validation set, you can use early stopping to find the optimal number of boosting rounds.
@ -215,4 +219,3 @@ When you use ``IPython``, you can use the :py:meth:`xgboost.to_graphviz` functio
.. code-block:: python
xgb.to_graphviz(bst, num_trees=2)

View File

@ -1041,8 +1041,8 @@ class Booster(object):
_check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
def update(self, dtrain, iteration, fobj=None):
"""
Update for one iteration, with objective function calculated internally.
"""Update for one iteration, with objective function calculated
internally. This function should not be called directly by users.
Parameters
----------
@ -1052,6 +1052,7 @@ class Booster(object):
Current iteration number.
fobj : function
Customized objective function.
"""
if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
@ -1066,8 +1067,9 @@ class Booster(object):
self.boost(dtrain, grad, hess)
def boost(self, dtrain, grad, hess):
"""
Boost the booster for one iteration, with customized gradient statistics.
"""Boost the booster for one iteration, with customized gradient
statistics. Like :func:`xgboost.core.Booster.update`, this
function should not be called directly by users.
Parameters
----------
@ -1077,6 +1079,7 @@ class Booster(object):
The first order of gradient.
hess : list
The second order of gradient.
"""
if len(grad) != len(hess):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))

View File

@ -703,6 +703,8 @@ class LearnerImpl : public Learner {
if (num_feature > mparam_.num_feature) {
mparam_.num_feature = num_feature;
}
CHECK_NE(mparam_.num_feature, 0)
<< "0 feature is supplied. Are you using raw Booster?";
// setup
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
CHECK(obj_ == nullptr && gbm_ == nullptr);