Don't validate features when the number of rows is 0. (#6472)
This commit is contained in:
@@ -2025,6 +2025,9 @@ class Booster(object):
|
||||
Validate Booster and data's feature_names are identical.
|
||||
Set feature_names and feature_types from DMatrix
|
||||
"""
|
||||
if data.num_row() == 0:
|
||||
return
|
||||
|
||||
if self.feature_names is None:
|
||||
self.feature_names = data.feature_names
|
||||
self.feature_types = data.feature_types
|
||||
|
||||
@@ -763,7 +763,8 @@ async def _direct_predict_impl(client, data, predict_fn):
|
||||
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def _predict_async(client, global_config, model, data, missing=numpy.nan, **kwargs):
|
||||
async def _predict_async(client, global_config, model, data, missing, validate_features,
|
||||
**kwargs):
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
@@ -779,7 +780,7 @@ async def _predict_async(client, global_config, model, data, missing=numpy.nan,
|
||||
with config.config_context(**global_config):
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||
predt = booster.predict(m, validate_features=validate_features, **kwargs)
|
||||
if is_df:
|
||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||
import cudf # pylint: disable=import-error
|
||||
@@ -821,7 +822,7 @@ async def _predict_async(client, global_config, model, data, missing=numpy.nan,
|
||||
)
|
||||
predt = booster.predict(
|
||||
data=local_part,
|
||||
validate_features=local_part.num_row() != 0,
|
||||
validate_features=validate_features,
|
||||
**kwargs)
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
ret = ((dask.delayed(predt), columns), order)
|
||||
@@ -878,7 +879,7 @@ async def _predict_async(client, global_config, model, data, missing=numpy.nan,
|
||||
return predictions
|
||||
|
||||
|
||||
def predict(client, model, data, missing=numpy.nan, **kwargs):
|
||||
def predict(client, model, data, missing=numpy.nan, validate_features=True, **kwargs):
|
||||
'''Run prediction with a trained booster.
|
||||
|
||||
.. note::
|
||||
@@ -910,7 +911,7 @@ def predict(client, model, data, missing=numpy.nan, **kwargs):
|
||||
client = _xgb_get_client(client)
|
||||
global_config = config.get_config()
|
||||
return client.sync(_predict_async, client, global_config, model, data,
|
||||
missing=missing, **kwargs)
|
||||
missing=missing, validate_features=validate_features, **kwargs)
|
||||
|
||||
|
||||
async def _inplace_predict_async(client, global_config, model, data,
|
||||
|
||||
Reference in New Issue
Block a user