Don't validate feature when number of rows is 0. (#6472)
This commit is contained in:
parent
55bdf084cb
commit
47b86180f6
@ -2025,6 +2025,9 @@ class Booster(object):
|
|||||||
Validate Booster and data's feature_names are identical.
|
Validate Booster and data's feature_names are identical.
|
||||||
Set feature_names and feature_types from DMatrix
|
Set feature_names and feature_types from DMatrix
|
||||||
"""
|
"""
|
||||||
|
if data.num_row() == 0:
|
||||||
|
return
|
||||||
|
|
||||||
if self.feature_names is None:
|
if self.feature_names is None:
|
||||||
self.feature_names = data.feature_names
|
self.feature_names = data.feature_names
|
||||||
self.feature_types = data.feature_types
|
self.feature_types = data.feature_types
|
||||||
|
|||||||
@ -763,7 +763,8 @@ async def _direct_predict_impl(client, data, predict_fn):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-statements
|
# pylint: disable=too-many-statements
|
||||||
async def _predict_async(client, global_config, model, data, missing=numpy.nan, **kwargs):
|
async def _predict_async(client, global_config, model, data, missing, validate_features,
|
||||||
|
**kwargs):
|
||||||
if isinstance(model, Booster):
|
if isinstance(model, Booster):
|
||||||
booster = model
|
booster = model
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
@ -779,7 +780,7 @@ async def _predict_async(client, global_config, model, data, missing=numpy.nan,
|
|||||||
with config.config_context(**global_config):
|
with config.config_context(**global_config):
|
||||||
booster.set_param({'nthread': worker.nthreads})
|
booster.set_param({'nthread': worker.nthreads})
|
||||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
predt = booster.predict(m, validate_features=validate_features, **kwargs)
|
||||||
if is_df:
|
if is_df:
|
||||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||||
import cudf # pylint: disable=import-error
|
import cudf # pylint: disable=import-error
|
||||||
@ -821,7 +822,7 @@ async def _predict_async(client, global_config, model, data, missing=numpy.nan,
|
|||||||
)
|
)
|
||||||
predt = booster.predict(
|
predt = booster.predict(
|
||||||
data=local_part,
|
data=local_part,
|
||||||
validate_features=local_part.num_row() != 0,
|
validate_features=validate_features,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||||
ret = ((dask.delayed(predt), columns), order)
|
ret = ((dask.delayed(predt), columns), order)
|
||||||
@ -878,7 +879,7 @@ async def _predict_async(client, global_config, model, data, missing=numpy.nan,
|
|||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def predict(client, model, data, missing=numpy.nan, **kwargs):
|
def predict(client, model, data, missing=numpy.nan, validate_features=True, **kwargs):
|
||||||
'''Run prediction with a trained booster.
|
'''Run prediction with a trained booster.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@ -910,7 +911,7 @@ def predict(client, model, data, missing=numpy.nan, **kwargs):
|
|||||||
client = _xgb_get_client(client)
|
client = _xgb_get_client(client)
|
||||||
global_config = config.get_config()
|
global_config = config.get_config()
|
||||||
return client.sync(_predict_async, client, global_config, model, data,
|
return client.sync(_predict_async, client, global_config, model, data,
|
||||||
missing=missing, **kwargs)
|
missing=missing, validate_features=validate_features, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def _inplace_predict_async(client, global_config, model, data,
|
async def _inplace_predict_async(client, global_config, model, data,
|
||||||
|
|||||||
@ -351,6 +351,30 @@ def test_sklearn_grid_search():
|
|||||||
assert len(means) == len(set(means))
|
assert len(means) == len(set(means))
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_dmatrix_training_continuation(client):
|
||||||
|
kRows, kCols = 1, 97
|
||||||
|
X = dd.from_array(np.random.randn(kRows, kCols))
|
||||||
|
y = dd.from_array(np.random.rand(kRows))
|
||||||
|
X.columns = ['X' + str(i) for i in range(0, 97)]
|
||||||
|
dtrain = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
|
||||||
|
kRows += 1000
|
||||||
|
X = dd.from_array(np.random.randn(kRows, kCols), chunksize=10)
|
||||||
|
X.columns = ['X' + str(i) for i in range(0, 97)]
|
||||||
|
y = dd.from_array(np.random.rand(kRows), chunksize=10)
|
||||||
|
valid = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
|
||||||
|
out = xgb.dask.train(client, {'tree_method': 'hist'},
|
||||||
|
dtrain=dtrain, num_boost_round=2,
|
||||||
|
evals=[(valid, 'validation')])
|
||||||
|
|
||||||
|
out = xgb.dask.train(client, {'tree_method': 'hist'},
|
||||||
|
dtrain=dtrain, xgb_model=out['booster'],
|
||||||
|
num_boost_round=2,
|
||||||
|
evals=[(valid, 'validation')])
|
||||||
|
assert xgb.dask.predict(client, out, dtrain).compute().shape[0] == 1
|
||||||
|
|
||||||
|
|
||||||
def run_empty_dmatrix_reg(client, parameters):
|
def run_empty_dmatrix_reg(client, parameters):
|
||||||
def _check_outputs(out, predictions):
|
def _check_outputs(out, predictions):
|
||||||
assert isinstance(out['booster'], xgb.dask.Booster)
|
assert isinstance(out['booster'], xgb.dask.Booster)
|
||||||
@ -371,6 +395,19 @@ def run_empty_dmatrix_reg(client, parameters):
|
|||||||
data=dtrain).compute()
|
data=dtrain).compute()
|
||||||
_check_outputs(out, predictions)
|
_check_outputs(out, predictions)
|
||||||
|
|
||||||
|
# valid has more rows than train
|
||||||
|
kRows += 1
|
||||||
|
X = dd.from_array(np.random.randn(kRows, kCols))
|
||||||
|
y = dd.from_array(np.random.rand(kRows))
|
||||||
|
valid = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
out = xgb.dask.train(client, parameters,
|
||||||
|
dtrain=dtrain,
|
||||||
|
evals=[(valid, 'validation')],
|
||||||
|
num_boost_round=2)
|
||||||
|
predictions = xgb.dask.predict(client=client, model=out,
|
||||||
|
data=dtrain).compute()
|
||||||
|
_check_outputs(out, predictions)
|
||||||
|
|
||||||
# train has more rows than evals
|
# train has more rows than evals
|
||||||
valid = dtrain
|
valid = dtrain
|
||||||
kRows += 1
|
kRows += 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user