Prevent training without setting up caches. (#4066)
* Prevent training without setting up caches. * Add warning for internal functions. * Check number of features. * Address reviewer's comment.
This commit is contained in:
parent
7a652a8c64
commit
1088dff42c
@ -161,6 +161,10 @@ A saved model can be loaded as follows:
|
|||||||
bst = xgb.Booster({'nthread': 4}) # init model
|
bst = xgb.Booster({'nthread': 4}) # init model
|
||||||
bst.load_model('model.bin') # load data
|
bst.load_model('model.bin') # load data
|
||||||
|
|
||||||
|
Methods including `update` and `boost` from `xgboost.Booster` are designed for
|
||||||
|
internal usage only. The wrapper function `xgboost.train` does some
|
||||||
|
pre-configuration including setting up caches and some other parameters.
|
||||||
|
|
||||||
Early Stopping
|
Early Stopping
|
||||||
--------------
|
--------------
|
||||||
If you have a validation set, you can use early stopping to find the optimal number of boosting rounds.
|
If you have a validation set, you can use early stopping to find the optimal number of boosting rounds.
|
||||||
@ -215,4 +219,3 @@ When you use ``IPython``, you can use the :py:meth:`xgboost.to_graphviz` functio
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
xgb.to_graphviz(bst, num_trees=2)
|
xgb.to_graphviz(bst, num_trees=2)
|
||||||
|
|
||||||
|
|||||||
@ -1041,8 +1041,8 @@ class Booster(object):
|
|||||||
_check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
|
_check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
|
||||||
|
|
||||||
def update(self, dtrain, iteration, fobj=None):
|
def update(self, dtrain, iteration, fobj=None):
|
||||||
"""
|
"""Update for one iteration, with objective function calculated
|
||||||
Update for one iteration, with objective function calculated internally.
|
internally. This function should not be called directly by users.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -1052,6 +1052,7 @@ class Booster(object):
|
|||||||
Current iteration number.
|
Current iteration number.
|
||||||
fobj : function
|
fobj : function
|
||||||
Customized objective function.
|
Customized objective function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(dtrain, DMatrix):
|
if not isinstance(dtrain, DMatrix):
|
||||||
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
|
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
|
||||||
@ -1066,8 +1067,9 @@ class Booster(object):
|
|||||||
self.boost(dtrain, grad, hess)
|
self.boost(dtrain, grad, hess)
|
||||||
|
|
||||||
def boost(self, dtrain, grad, hess):
|
def boost(self, dtrain, grad, hess):
|
||||||
"""
|
"""Boost the booster for one iteration, with customized gradient
|
||||||
Boost the booster for one iteration, with customized gradient statistics.
|
statistics. Like :func:`xgboost.core.Booster.update`, this
|
||||||
|
function should not be called directly by users.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -1077,6 +1079,7 @@ class Booster(object):
|
|||||||
The first order of gradient.
|
The first order of gradient.
|
||||||
hess : list
|
hess : list
|
||||||
The second order of gradient.
|
The second order of gradient.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if len(grad) != len(hess):
|
if len(grad) != len(hess):
|
||||||
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
|
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
|
||||||
|
|||||||
@ -703,6 +703,8 @@ class LearnerImpl : public Learner {
|
|||||||
if (num_feature > mparam_.num_feature) {
|
if (num_feature > mparam_.num_feature) {
|
||||||
mparam_.num_feature = num_feature;
|
mparam_.num_feature = num_feature;
|
||||||
}
|
}
|
||||||
|
CHECK_NE(mparam_.num_feature, 0)
|
||||||
|
<< "0 feature is supplied. Are you using raw Booster?";
|
||||||
// setup
|
// setup
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||||
CHECK(obj_ == nullptr && gbm_ == nullptr);
|
CHECK(obj_ == nullptr && gbm_ == nullptr);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user