Back port fixes to 1.2 (#6002)
* Fix sklearn doc. (#5980) * Enforce tree order in JSON. (#5974) * Make JSON model IO more future proof by using tree id in model loading. * Fix dask predict shape infer. (#5989) * [Breaking] Fix .predict() method and add .predict_proba() in xgboost.dask.DaskXGBClassifier (#5986)
This commit is contained in:
parent
7856da5827
commit
936a854baa
@ -738,7 +738,8 @@ async def _predict_async(client: Client, model, data, *args,
|
|||||||
predt = booster.predict(data=local_x,
|
predt = booster.predict(data=local_x,
|
||||||
validate_features=local_x.num_row() != 0,
|
validate_features=local_x.num_row() != 0,
|
||||||
*args)
|
*args)
|
||||||
ret = (delayed(predt), order)
|
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||||
|
ret = ((delayed(predt), columns), order)
|
||||||
predictions.append(ret)
|
predictions.append(ret)
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
@ -775,7 +776,9 @@ async def _predict_async(client: Client, model, data, *args,
|
|||||||
# See https://docs.dask.org/en/latest/array-creation.html
|
# See https://docs.dask.org/en/latest/array-creation.html
|
||||||
arrays = []
|
arrays = []
|
||||||
for i, shape in enumerate(shapes):
|
for i, shape in enumerate(shapes):
|
||||||
arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
|
arrays.append(da.from_delayed(
|
||||||
|
results[i][0], shape=(shape[0],)
|
||||||
|
if results[i][1] == 1 else (shape[0], results[i][1]),
|
||||||
dtype=numpy.float32))
|
dtype=numpy.float32))
|
||||||
predictions = await da.concatenate(arrays, axis=0)
|
predictions = await da.concatenate(arrays, axis=0)
|
||||||
return predictions
|
return predictions
|
||||||
@ -978,6 +981,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
def client(self, clt):
|
def client(self, clt):
|
||||||
self._client = clt
|
self._client = clt
|
||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||||
['estimators', 'model'])
|
['estimators', 'model'])
|
||||||
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||||
@ -1032,9 +1036,6 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
|||||||
['estimators', 'model']
|
['estimators', 'model']
|
||||||
)
|
)
|
||||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||||
# pylint: disable=missing-docstring
|
|
||||||
_client = None
|
|
||||||
|
|
||||||
async def _fit_async(self, X, y,
|
async def _fit_async(self, X, y,
|
||||||
sample_weights=None,
|
sample_weights=None,
|
||||||
eval_set=None,
|
eval_set=None,
|
||||||
@ -1078,13 +1079,34 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
return self.client.sync(self._fit_async, X, y, sample_weights,
|
return self.client.sync(self._fit_async, X, y, sample_weights,
|
||||||
eval_set, sample_weight_eval_set, verbose)
|
eval_set, sample_weight_eval_set, verbose)
|
||||||
|
|
||||||
async def _predict_async(self, data):
|
async def _predict_proba_async(self, data):
|
||||||
|
_assert_dask_support()
|
||||||
|
|
||||||
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
|
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
|
||||||
missing=self.missing)
|
missing=self.missing)
|
||||||
pred_probs = await predict(client=self.client,
|
pred_probs = await predict(client=self.client,
|
||||||
model=self.get_booster(), data=test_dmatrix)
|
model=self.get_booster(), data=test_dmatrix)
|
||||||
return pred_probs
|
return pred_probs
|
||||||
|
|
||||||
|
def predict_proba(self, data): # pylint: disable=arguments-differ,missing-docstring
|
||||||
|
_assert_dask_support()
|
||||||
|
return self.client.sync(self._predict_proba_async, data)
|
||||||
|
|
||||||
|
async def _predict_async(self, data):
|
||||||
|
_assert_dask_support()
|
||||||
|
|
||||||
|
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
|
||||||
|
missing=self.missing)
|
||||||
|
pred_probs = await predict(client=self.client,
|
||||||
|
model=self.get_booster(), data=test_dmatrix)
|
||||||
|
|
||||||
|
if self.n_classes_ == 2:
|
||||||
|
preds = (pred_probs > 0.5).astype(int)
|
||||||
|
else:
|
||||||
|
preds = da.argmax(pred_probs, axis=1)
|
||||||
|
|
||||||
|
return preds
|
||||||
|
|
||||||
def predict(self, data): # pylint: disable=arguments-differ
|
def predict(self, data): # pylint: disable=arguments-differ
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
return self.client.sync(self._predict_async, data)
|
return self.client.sync(self._predict_async, data)
|
||||||
|
|||||||
@ -77,7 +77,7 @@ __model_doc = '''
|
|||||||
gamma : float
|
gamma : float
|
||||||
Minimum loss reduction required to make a further partition on a leaf
|
Minimum loss reduction required to make a further partition on a leaf
|
||||||
node of the tree.
|
node of the tree.
|
||||||
min_child_weight : int
|
min_child_weight : float
|
||||||
Minimum sum of instance weight(hessian) needed in a child.
|
Minimum sum of instance weight(hessian) needed in a child.
|
||||||
max_delta_step : int
|
max_delta_step : int
|
||||||
Maximum delta step we allow each tree's weight estimation to be.
|
Maximum delta step we allow each tree's weight estimation to be.
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2019 by Contributors
|
* Copyright 2019-2020 by Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "gbtree_model.h"
|
#include "gbtree_model.h"
|
||||||
@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const {
|
|||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||||
out["gbtree_model_param"] = ToJson(param);
|
out["gbtree_model_param"] = ToJson(param);
|
||||||
std::vector<Json> trees_json;
|
std::vector<Json> trees_json(trees.size());
|
||||||
size_t t = 0;
|
|
||||||
for (auto const& tree : trees) {
|
for (size_t t = 0; t < trees.size(); ++t) {
|
||||||
|
auto const& tree = trees[t];
|
||||||
Json tree_json{Object()};
|
Json tree_json{Object()};
|
||||||
tree->SaveModel(&tree_json);
|
tree->SaveModel(&tree_json);
|
||||||
// The field is not used in XGBoost, but might be useful for external project.
|
tree_json["id"] = Integer(static_cast<Integer::Int>(t));
|
||||||
tree_json["id"] = Integer(t);
|
trees_json[t] = std::move(tree_json);
|
||||||
trees_json.emplace_back(tree_json);
|
|
||||||
t++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Json> tree_info_json(tree_info.size());
|
std::vector<Json> tree_info_json(tree_info.size());
|
||||||
@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) {
|
|||||||
auto const& trees_json = get<Array const>(in["trees"]);
|
auto const& trees_json = get<Array const>(in["trees"]);
|
||||||
trees.resize(trees_json.size());
|
trees.resize(trees_json.size());
|
||||||
|
|
||||||
for (size_t t = 0; t < trees.size(); ++t) {
|
for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT
|
||||||
trees[t].reset( new RegTree() );
|
auto tree_id = get<Integer>(trees_json[t]["id"]);
|
||||||
trees[t]->LoadModel(trees_json[t]);
|
trees.at(tree_id).reset(new RegTree());
|
||||||
|
trees.at(tree_id)->LoadModel(trees_json[t]);
|
||||||
}
|
}
|
||||||
|
|
||||||
tree_info.resize(param.num_trees);
|
tree_info.resize(param.num_trees);
|
||||||
|
|||||||
@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) {
|
|||||||
Json out { Object() };
|
Json out { Object() };
|
||||||
learner->SaveModel(&out);
|
learner->SaveModel(&out);
|
||||||
|
|
||||||
learner->LoadModel(out);
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
|
||||||
|
std::ofstream fout (tmpdir.path + "/model.json");
|
||||||
|
fout << out;
|
||||||
|
fout.close();
|
||||||
|
|
||||||
|
auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json");
|
||||||
|
Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()});
|
||||||
|
|
||||||
|
learner->LoadModel(loaded);
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
|
|
||||||
Json new_in { Object() };
|
Json new_in { Object() };
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from sklearn.datasets import make_classification
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
@ -36,7 +37,7 @@ def generate_array():
|
|||||||
|
|
||||||
|
|
||||||
def test_from_dask_dataframe():
|
def test_from_dask_dataframe():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ def test_from_dask_dataframe():
|
|||||||
|
|
||||||
|
|
||||||
def test_from_dask_array():
|
def test_from_dask_array():
|
||||||
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
|
with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
dtrain = DaskDMatrix(client, X, y)
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
@ -104,8 +105,28 @@ def test_from_dask_array():
|
|||||||
assert np.all(single_node_predt == from_arr.compute())
|
assert np.all(single_node_predt == from_arr.compute())
|
||||||
|
|
||||||
|
|
||||||
|
def test_dask_predict_shape_infer():
|
||||||
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
X, y = make_classification(n_samples=1000, n_informative=5,
|
||||||
|
n_classes=3)
|
||||||
|
X_ = dd.from_array(X, chunksize=100)
|
||||||
|
y_ = dd.from_array(y, chunksize=100)
|
||||||
|
dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
|
||||||
|
|
||||||
|
model = xgb.dask.train(
|
||||||
|
client,
|
||||||
|
{"objective": "multi:softprob", "num_class": 3},
|
||||||
|
dtrain=dtrain
|
||||||
|
)
|
||||||
|
|
||||||
|
preds = xgb.dask.predict(client, model, dtrain)
|
||||||
|
assert preds.shape[0] == preds.compute().shape[0]
|
||||||
|
assert preds.shape[1] == preds.compute().shape[1]
|
||||||
|
|
||||||
|
|
||||||
def test_dask_missing_value_reg():
|
def test_dask_missing_value_reg():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X_0 = np.ones((20 // 2, kCols))
|
X_0 = np.ones((20 // 2, kCols))
|
||||||
X_1 = np.zeros((20 // 2, kCols))
|
X_1 = np.zeros((20 // 2, kCols))
|
||||||
@ -144,19 +165,19 @@ def test_dask_missing_value_cls():
|
|||||||
missing=0.0)
|
missing=0.0)
|
||||||
cls.client = client
|
cls.client = client
|
||||||
cls.fit(X, y, eval_set=[(X, y)])
|
cls.fit(X, y, eval_set=[(X, y)])
|
||||||
dd_predt = cls.predict(X).compute()
|
dd_pred_proba = cls.predict_proba(X).compute()
|
||||||
|
|
||||||
np_X = X.compute()
|
np_X = X.compute()
|
||||||
np_predt = cls.get_booster().predict(
|
np_pred_proba = cls.get_booster().predict(
|
||||||
xgb.DMatrix(np_X, missing=0.0))
|
xgb.DMatrix(np_X, missing=0.0))
|
||||||
np.testing.assert_allclose(np_predt, dd_predt)
|
np.testing.assert_allclose(np_pred_proba, dd_pred_proba)
|
||||||
|
|
||||||
cls = xgb.dask.DaskXGBClassifier()
|
cls = xgb.dask.DaskXGBClassifier()
|
||||||
assert hasattr(cls, 'missing')
|
assert hasattr(cls, 'missing')
|
||||||
|
|
||||||
|
|
||||||
def test_dask_regressor():
|
def test_dask_regressor():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||||
@ -178,7 +199,7 @@ def test_dask_regressor():
|
|||||||
|
|
||||||
|
|
||||||
def test_dask_classifier():
|
def test_dask_classifier():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
y = (y * 10).astype(np.int32)
|
y = (y * 10).astype(np.int32)
|
||||||
@ -201,7 +222,18 @@ def test_dask_classifier():
|
|||||||
assert len(list(history['validation_0'])) == 1
|
assert len(list(history['validation_0'])) == 1
|
||||||
assert len(history['validation_0']['merror']) == 2
|
assert len(history['validation_0']['merror']) == 2
|
||||||
|
|
||||||
|
# Test .predict_proba()
|
||||||
|
probas = classifier.predict_proba(X)
|
||||||
assert classifier.n_classes_ == 10
|
assert classifier.n_classes_ == 10
|
||||||
|
assert probas.ndim == 2
|
||||||
|
assert probas.shape[0] == kRows
|
||||||
|
assert probas.shape[1] == 10
|
||||||
|
|
||||||
|
cls_booster = classifier.get_booster()
|
||||||
|
single_node_proba = cls_booster.inplace_predict(X.compute())
|
||||||
|
|
||||||
|
np.testing.assert_allclose(single_node_proba,
|
||||||
|
probas.compute())
|
||||||
|
|
||||||
# Test with dataframe.
|
# Test with dataframe.
|
||||||
X_d = dd.from_dask_array(X)
|
X_d = dd.from_dask_array(X)
|
||||||
@ -218,7 +250,7 @@ def test_dask_classifier():
|
|||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_sklearn_grid_search():
|
def test_sklearn_grid_search():
|
||||||
from sklearn.model_selection import GridSearchCV
|
from sklearn.model_selection import GridSearchCV
|
||||||
with LocalCluster(n_workers=4) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
|
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
|
||||||
@ -292,7 +324,9 @@ def run_empty_dmatrix_cls(client, parameters):
|
|||||||
evals=[(dtrain, 'validation')],
|
evals=[(dtrain, 'validation')],
|
||||||
num_boost_round=2)
|
num_boost_round=2)
|
||||||
predictions = xgb.dask.predict(client=client, model=out,
|
predictions = xgb.dask.predict(client=client, model=out,
|
||||||
data=dtrain).compute()
|
data=dtrain)
|
||||||
|
assert predictions.shape[1] == n_classes
|
||||||
|
predictions = predictions.compute()
|
||||||
_check_outputs(out, predictions)
|
_check_outputs(out, predictions)
|
||||||
|
|
||||||
# train has more rows than evals
|
# train has more rows than evals
|
||||||
@ -315,7 +349,7 @@ def run_empty_dmatrix_cls(client, parameters):
|
|||||||
# environment and Exact doesn't support it.
|
# environment and Exact doesn't support it.
|
||||||
|
|
||||||
def test_empty_dmatrix_hist():
|
def test_empty_dmatrix_hist():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
parameters = {'tree_method': 'hist'}
|
parameters = {'tree_method': 'hist'}
|
||||||
run_empty_dmatrix_reg(client, parameters)
|
run_empty_dmatrix_reg(client, parameters)
|
||||||
@ -323,7 +357,7 @@ def test_empty_dmatrix_hist():
|
|||||||
|
|
||||||
|
|
||||||
def test_empty_dmatrix_approx():
|
def test_empty_dmatrix_approx():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
parameters = {'tree_method': 'approx'}
|
parameters = {'tree_method': 'approx'}
|
||||||
run_empty_dmatrix_reg(client, parameters)
|
run_empty_dmatrix_reg(client, parameters)
|
||||||
@ -397,7 +431,13 @@ async def run_dask_classifier_asyncio(scheduler_address):
|
|||||||
assert len(list(history['validation_0'])) == 1
|
assert len(list(history['validation_0'])) == 1
|
||||||
assert len(history['validation_0']['merror']) == 2
|
assert len(history['validation_0']['merror']) == 2
|
||||||
|
|
||||||
|
# Test .predict_proba()
|
||||||
|
probas = await classifier.predict_proba(X)
|
||||||
assert classifier.n_classes_ == 10
|
assert classifier.n_classes_ == 10
|
||||||
|
assert probas.ndim == 2
|
||||||
|
assert probas.shape[0] == kRows
|
||||||
|
assert probas.shape[1] == 10
|
||||||
|
|
||||||
|
|
||||||
# Test with dataframe.
|
# Test with dataframe.
|
||||||
X_d = dd.from_dask_array(X)
|
X_d = dd.from_dask_array(X)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user