Don't set_params at the end of set_state. (#4947)
* Don't set_params at the end of set_state.
* Also fix another issue found in dask prediction.
* Add note about prediction. Don't support other prediction modes at the moment.
This commit is contained in:
parent
2ebdec8aa6
commit
7e72a12871
@ -32,6 +32,7 @@ def main(client):
|
|||||||
|
|
||||||
# you can pass output directly into `predict` too.
|
# you can pass output directly into `predict` too.
|
||||||
prediction = xgb.dask.predict(client, bst, dtrain)
|
prediction = xgb.dask.predict(client, bst, dtrain)
|
||||||
|
prediction = prediction.compute()
|
||||||
print('Evaluation history:', history)
|
print('Evaluation history:', history)
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
|
|||||||
@ -1125,7 +1125,6 @@ class Booster(object):
|
|||||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(handle, ptr, length))
|
_check_call(_LIB.XGBoosterLoadModelFromBuffer(handle, ptr, length))
|
||||||
state['handle'] = handle
|
state['handle'] = handle
|
||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
self.set_param({'seed': 0})
|
|
||||||
|
|
||||||
def __copy__(self):
|
def __copy__(self):
|
||||||
return self.__deepcopy__(None)
|
return self.__deepcopy__(None)
|
||||||
|
|||||||
@ -395,6 +395,10 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
|||||||
def predict(client, model, data, *args):
|
def predict(client, model, data, *args):
|
||||||
'''Run prediction with a trained booster.
|
'''Run prediction with a trained booster.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Only default prediction mode is supported right now.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
client: dask.distributed.Client
|
client: dask.distributed.Client
|
||||||
@ -445,8 +449,8 @@ def predict(client, model, data, *args):
|
|||||||
'''Get shape of data in each worker.'''
|
'''Get shape of data in each worker.'''
|
||||||
logging.info('Trying to get data shape on %d', worker_id)
|
logging.info('Trying to get data shape on %d', worker_id)
|
||||||
worker = distributed_get_worker()
|
worker = distributed_get_worker()
|
||||||
rows, cols = data.get_worker_data_shape(worker)
|
rows, _ = data.get_worker_data_shape(worker)
|
||||||
return rows, cols
|
return rows, 1 # default is 1
|
||||||
|
|
||||||
# Constructing a dask array from list of numpy arrays
|
# Constructing a dask array from list of numpy arrays
|
||||||
# See https://docs.dask.org/en/latest/array-creation.html
|
# See https://docs.dask.org/en/latest/array-creation.html
|
||||||
@ -457,7 +461,7 @@ def predict(client, model, data, *args):
|
|||||||
shapes = client.gather(futures_shape)
|
shapes = client.gather(futures_shape)
|
||||||
arrays = []
|
arrays = []
|
||||||
for i in range(len(futures_shape)):
|
for i in range(len(futures_shape)):
|
||||||
arrays.append(da.from_delayed(futures[i], shape=shapes[i],
|
arrays.append(da.from_delayed(futures[i], shape=(shapes[i][0], ),
|
||||||
dtype=numpy.float32))
|
dtype=numpy.float32))
|
||||||
predictions = da.concatenate(arrays, axis=0)
|
predictions = da.concatenate(arrays, axis=0)
|
||||||
return predictions
|
return predictions
|
||||||
|
|||||||
@ -40,3 +40,6 @@ def test_dask_dataframe(client):
|
|||||||
|
|
||||||
assert isinstance(out['booster'], dxgb.Booster)
|
assert isinstance(out['booster'], dxgb.Booster)
|
||||||
assert len(out['history']['X']['rmse']) == 2
|
assert len(out['history']['X']['rmse']) == 2
|
||||||
|
|
||||||
|
predictions = dxgb.predict(out, dtrain)
|
||||||
|
predictions = predictions.compute()
|
||||||
|
|||||||
48
tests/python/test_pickling.py
Normal file
48
tests/python/test_pickling.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import xgboost as xgb
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
kRows = 100
|
||||||
|
kCols = 10
|
||||||
|
|
||||||
|
|
||||||
|
def generate_data():
|
||||||
|
X = np.random.randn(kRows, kCols)
|
||||||
|
y = np.random.randn(kRows)
|
||||||
|
return X, y
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_pickling():
|
||||||
|
xgb_params = {
|
||||||
|
'verbosity': 0,
|
||||||
|
'nthread': 1,
|
||||||
|
'tree_method': 'hist'
|
||||||
|
}
|
||||||
|
|
||||||
|
X, y = generate_data()
|
||||||
|
dtrain = xgb.DMatrix(X, y)
|
||||||
|
bst = xgb.train(xgb_params, dtrain)
|
||||||
|
|
||||||
|
dump_0 = bst.get_dump(dump_format='json')
|
||||||
|
assert dump_0
|
||||||
|
|
||||||
|
filename = 'model.pkl'
|
||||||
|
|
||||||
|
with open(filename, 'wb') as fd:
|
||||||
|
pickle.dump(bst, fd)
|
||||||
|
|
||||||
|
with open(filename, 'rb') as fd:
|
||||||
|
bst = pickle.load(fd)
|
||||||
|
|
||||||
|
with open(filename, 'wb') as fd:
|
||||||
|
pickle.dump(bst, fd)
|
||||||
|
|
||||||
|
with open(filename, 'rb') as fd:
|
||||||
|
bst = pickle.load(fd)
|
||||||
|
|
||||||
|
assert bst.get_dump(dump_format='json') == dump_0
|
||||||
|
|
||||||
|
if os.path.exists(filename):
|
||||||
|
os.remove(filename)
|
||||||
@ -43,14 +43,17 @@ def test_from_dask_dataframe(client):
|
|||||||
|
|
||||||
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||||
|
|
||||||
|
assert prediction.ndim == 1
|
||||||
assert isinstance(prediction, da.Array)
|
assert isinstance(prediction, da.Array)
|
||||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# evals_result is not supported in dask interface.
|
# evals_result is not supported in dask interface.
|
||||||
xgb.dask.train(
|
xgb.dask.train(
|
||||||
client, {}, dtrain, num_boost_round=2, evals_result={})
|
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||||
|
|
||||||
|
prediction = prediction.compute() # force prediction to be computed
|
||||||
|
|
||||||
|
|
||||||
def test_from_dask_array(client):
|
def test_from_dask_array(client):
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
@ -59,10 +62,12 @@ def test_from_dask_array(client):
|
|||||||
result = xgb.dask.train(client, {}, dtrain)
|
result = xgb.dask.train(client, {}, dtrain)
|
||||||
|
|
||||||
prediction = xgb.dask.predict(client, result, dtrain)
|
prediction = xgb.dask.predict(client, result, dtrain)
|
||||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
assert isinstance(prediction, da.Array)
|
assert isinstance(prediction, da.Array)
|
||||||
|
|
||||||
|
prediction = prediction.compute() # force prediction to be computed
|
||||||
|
|
||||||
|
|
||||||
def test_regressor(client):
|
def test_regressor(client):
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
@ -72,7 +77,8 @@ def test_regressor(client):
|
|||||||
regressor.fit(X, y, eval_set=[(X, y)])
|
regressor.fit(X, y, eval_set=[(X, y)])
|
||||||
prediction = regressor.predict(X)
|
prediction = regressor.predict(X)
|
||||||
|
|
||||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
assert prediction.ndim == 1
|
||||||
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
history = regressor.evals_result()
|
history = regressor.evals_result()
|
||||||
|
|
||||||
@ -91,7 +97,8 @@ def test_classifier(client):
|
|||||||
classifier.fit(X, y, eval_set=[(X, y)])
|
classifier.fit(X, y, eval_set=[(X, y)])
|
||||||
prediction = classifier.predict(X)
|
prediction = classifier.predict(X)
|
||||||
|
|
||||||
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
|
assert prediction.ndim == 1
|
||||||
|
assert prediction.shape[0] == kRows
|
||||||
|
|
||||||
history = classifier.evals_result()
|
history = classifier.evals_result()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user