diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index b97855cbe..a08c21367 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -738,7 +738,8 @@ async def _predict_async(client: Client, model, data, *args,
             predt = booster.predict(data=local_x,
                                     validate_features=local_x.num_row() != 0,
                                     *args)
-            ret = (delayed(predt), order)
+            columns = 1 if len(predt.shape) == 1 else predt.shape[1]
+            ret = ((delayed(predt), columns), order)
             predictions.append(ret)
 
         return predictions
@@ -775,8 +776,10 @@
     # See https://docs.dask.org/en/latest/array-creation.html
     arrays = []
     for i, shape in enumerate(shapes):
-        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
-                                      dtype=numpy.float32))
+        arrays.append(da.from_delayed(
+            results[i][0], shape=(shape[0],)
+            if results[i][1] == 1 else (shape[0], results[i][1]),
+            dtype=numpy.float32))
 
     predictions = await da.concatenate(arrays, axis=0)
     return predictions
@@ -978,6 +981,7 @@ class DaskScikitLearnBase(XGBModel):
     def client(self, clt):
         self._client = clt
 
+
 @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
                    ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
@@ -1032,9 +1036,6 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     ['estimators', 'model']
 )
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
-    # pylint: disable=missing-docstring
-    _client = None
-
     async def _fit_async(self, X, y,
                          sample_weights=None,
                          eval_set=None,
@@ -1078,13 +1079,34 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         return self.client.sync(self._fit_async, X, y, sample_weights,
                                 eval_set, sample_weight_eval_set, verbose)
 
-    async def _predict_async(self, data):
+    async def _predict_proba_async(self, data):
+        _assert_dask_support()
+
         test_dmatrix = await DaskDMatrix(client=self.client, data=data,
                                          missing=self.missing)
         pred_probs = await predict(client=self.client,
                                    model=self.get_booster(), data=test_dmatrix)
         return pred_probs
 
+    def predict_proba(self, data):  # pylint: disable=arguments-differ,missing-docstring
+        _assert_dask_support()
+        return self.client.sync(self._predict_proba_async, data)
+
+    async def _predict_async(self, data):
+        _assert_dask_support()
+
+        test_dmatrix = await DaskDMatrix(client=self.client, data=data,
+                                         missing=self.missing)
+        pred_probs = await predict(client=self.client,
+                                   model=self.get_booster(), data=test_dmatrix)
+
+        if self.n_classes_ == 2:
+            preds = (pred_probs > 0.5).astype(int)
+        else:
+            preds = da.argmax(pred_probs, axis=1)
+
+        return preds
+
     def predict(self, data):  # pylint: disable=arguments-differ
         _assert_dask_support()
         return self.client.sync(self._predict_async, data)
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 1f3033f2d..f533f7f34 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -77,7 +77,7 @@ __model_doc = '''
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf
         node of the tree.
-    min_child_weight : int
+    min_child_weight : float
         Minimum sum of instance weight(hessian) needed in a child.
     max_delta_step : int
         Maximum delta step we allow each tree's weight estimation to be.
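The `_predict_async` change above threads a column count through each delayed result so the driver can build 1-D or 2-D dask arrays without computing any worker output. A minimal sketch of that bookkeeping, outside xgboost and with made-up chunk data:

```python
import numpy as np
import dask.array as da
from dask import delayed

# Stand-ins for per-worker prediction chunks; 2-D here, as in multi-class
# probability output.  A 1-D chunk would carry a column count of 1.
chunks = [np.random.rand(10, 3).astype(np.float32),
          np.random.rand(5, 3).astype(np.float32)]
shapes = [c.shape for c in chunks]

# Pair each delayed chunk with its column count, mirroring the new
# ``ret = ((delayed(predt), columns), order)`` tuple in the diff.
results = [(delayed(c), 1 if c.ndim == 1 else c.shape[1]) for c in chunks]

arrays = []
for i, shape in enumerate(shapes):
    columns = results[i][1]
    arrays.append(da.from_delayed(
        results[i][0],
        shape=(shape[0],) if columns == 1 else (shape[0], columns),
        dtype=np.float32))

predictions = da.concatenate(arrays, axis=0)
assert predictions.shape == (15, 3)  # known without computing any chunk
```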
diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc
index a53346797..8ebd8284c 100644
--- a/src/gbm/gbtree_model.cc
+++ b/src/gbm/gbtree_model.cc
@@ -1,6 +1,8 @@
 /*!
- * Copyright 2019 by Contributors
+ * Copyright 2019-2020 by Contributors
  */
+#include <utility>
+
 #include "xgboost/json.h"
 #include "xgboost/logging.h"
 #include "gbtree_model.h"
@@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const {
   auto& out = *p_out;
   CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
   out["gbtree_model_param"] = ToJson(param);
-  std::vector<Json> trees_json;
-  size_t t = 0;
-  for (auto const& tree : trees) {
+  std::vector<Json> trees_json(trees.size());
+
+  for (size_t t = 0; t < trees.size(); ++t) {
+    auto const& tree = trees[t];
     Json tree_json{Object()};
     tree->SaveModel(&tree_json);
-    // The field is not used in XGBoost, but might be useful for external project.
-    tree_json["id"] = Integer(t);
-    trees_json.emplace_back(tree_json);
-    t++;
+    tree_json["id"] = Integer(static_cast<Integer::Int>(t));
+    trees_json[t] = std::move(tree_json);
   }
 
   std::vector<Json> tree_info_json(tree_info.size());
@@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) {
   auto const& trees_json = get<Array const>(in["trees"]);
   trees.resize(trees_json.size());
 
-  for (size_t t = 0; t < trees.size(); ++t) {
-    trees[t].reset( new RegTree() );
-    trees[t]->LoadModel(trees_json[t]);
+  for (size_t t = 0; t < trees_json.size(); ++t) {  // NOLINT
+    auto tree_id = get<Integer const>(trees_json[t]["id"]);
+    trees.at(tree_id).reset(new RegTree());
+    trees.at(tree_id)->LoadModel(trees_json[t]);
   }
 
   tree_info.resize(param.num_trees);
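`LoadModel` now places each tree at the slot given by its persisted `id` instead of trusting the array order. A hedged sketch of what that field looks like from the Python side; the `learner.gradient_booster.model.trees` path is assumed from the JSON model schema of this era, not from the diff itself:

```python
import json
import numpy as np
import xgboost as xgb

X = np.random.rand(100, 10)
y = np.random.randint(0, 2, size=100)
booster = xgb.train({"objective": "binary:logistic"},
                    xgb.DMatrix(X, label=y), num_boost_round=3)
booster.save_model("model.json")  # the .json suffix selects the JSON format

with open("model.json") as fd:
    model = json.load(fd)

trees = model["learner"]["gradient_booster"]["model"]["trees"]
# SaveModel writes an "id" per tree; LoadModel indexes trees.at(tree_id),
# so the array could be serialized in any order and still load correctly.
assert sorted(t["id"] for t in trees) == [0, 1, 2]
```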
diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc
index 7d473f00c..56e4a95ec 100644
--- a/tests/cpp/test_learner.cc
+++ b/tests/cpp/test_learner.cc
@@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) {
   Json out { Object() };
   learner->SaveModel(&out);
 
-  learner->LoadModel(out);
+  dmlc::TemporaryDirectory tmpdir;
+
+  std::ofstream fout (tmpdir.path + "/model.json");
+  fout << out;
+  fout.close();
+
+  auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json");
+  Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()});
+
+  learner->LoadModel(loaded);
   learner->Configure();
 
   Json new_in { Object() };
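The C++ test now round-trips the model through a real file instead of reloading the in-memory `Json` object, which also exercises the text parser. A rough Python analogue of the same round-trip, using only public APIs:

```python
import numpy as np
import xgboost as xgb

X = np.random.rand(64, 4)
y = np.random.rand(64)
booster = xgb.train({}, xgb.DMatrix(X, label=y), num_boost_round=2)

# Serialize to JSON on disk, then parse it back through the loader.
booster.save_model("model.json")
restored = xgb.Booster(model_file="model.json")

# The reloaded model should describe identical trees.
assert booster.get_dump() == restored.get_dump()
```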
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index a0825b523..b4be33ed3 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -5,6 +5,7 @@ import sys
 import numpy as np
 import json
 import asyncio
+from sklearn.datasets import make_classification
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -36,7 +37,7 @@ def generate_array():
 
 
 def test_from_dask_dataframe():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
 
@@ -74,7 +75,7 @@
 
 
 def test_from_dask_array():
-    with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
+    with LocalCluster(n_workers=kWorkers, threads_per_worker=5) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             dtrain = DaskDMatrix(client, X, y)
@@ -104,8 +105,28 @@
             assert np.all(single_node_predt == from_arr.compute())
 
 
+def test_dask_predict_shape_infer():
+    with LocalCluster(n_workers=kWorkers) as cluster:
+        with Client(cluster) as client:
+            X, y = make_classification(n_samples=1000, n_informative=5,
+                                       n_classes=3)
+            X_ = dd.from_array(X, chunksize=100)
+            y_ = dd.from_array(y, chunksize=100)
+            dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
+
+            model = xgb.dask.train(
+                client,
+                {"objective": "multi:softprob", "num_class": 3},
+                dtrain=dtrain
+            )
+
+            preds = xgb.dask.predict(client, model, dtrain)
+            assert preds.shape[0] == preds.compute().shape[0]
+            assert preds.shape[1] == preds.compute().shape[1]
+
+
 def test_dask_missing_value_reg():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X_0 = np.ones((20 // 2, kCols))
             X_1 = np.zeros((20 // 2, kCols))
@@ -144,19 +165,19 @@ def test_dask_missing_value_cls():
                                              missing=0.0)
             cls.client = client
             cls.fit(X, y, eval_set=[(X, y)])
-            dd_predt = cls.predict(X).compute()
+            dd_pred_proba = cls.predict_proba(X).compute()
 
             np_X = X.compute()
-            np_predt = cls.get_booster().predict(
+            np_pred_proba = cls.get_booster().predict(
                 xgb.DMatrix(np_X, missing=0.0))
-            np.testing.assert_allclose(np_predt, dd_predt)
+            np.testing.assert_allclose(np_pred_proba, dd_pred_proba)
 
             cls = xgb.dask.DaskXGBClassifier()
             assert hasattr(cls, 'missing')
 
 
 def test_dask_regressor():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
@@ -178,7 +199,7 @@
 
 
 def test_dask_classifier():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             y = (y * 10).astype(np.int32)
@@ -201,7 +222,18 @@
             assert len(list(history['validation_0'])) == 1
             assert len(history['validation_0']['merror']) == 2
 
+            # Test .predict_proba()
+            probas = classifier.predict_proba(X)
             assert classifier.n_classes_ == 10
+            assert probas.ndim == 2
+            assert probas.shape[0] == kRows
+            assert probas.shape[1] == 10
+
+            cls_booster = classifier.get_booster()
+            single_node_proba = cls_booster.inplace_predict(X.compute())
+
+            np.testing.assert_allclose(single_node_proba,
+                                       probas.compute())
 
             # Test with dataframe.
             X_d = dd.from_dask_array(X)
@@ -218,7 +250,7 @@
 @pytest.mark.skipif(**tm.no_sklearn())
 def test_sklearn_grid_search():
     from sklearn.model_selection import GridSearchCV
-    with LocalCluster(n_workers=4) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             X, y = generate_array()
             reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
@@ -292,7 +324,9 @@ def run_empty_dmatrix_cls(client, parameters):
                           evals=[(dtrain, 'validation')],
                           num_boost_round=2)
     predictions = xgb.dask.predict(client=client, model=out,
-                                   data=dtrain).compute()
+                                   data=dtrain)
    assert predictions.shape[1] == n_classes
+    predictions = predictions.compute()
     _check_outputs(out, predictions)
 
     # train has more rows than evals
@@ -315,7 +349,7 @@
 
 # environment and Exact doesn't support it.
 def test_empty_dmatrix_hist():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             parameters = {'tree_method': 'hist'}
             run_empty_dmatrix_reg(client, parameters)
@@ -323,7 +357,7 @@
 
 
 def test_empty_dmatrix_approx():
-    with LocalCluster(n_workers=5) as cluster:
+    with LocalCluster(n_workers=kWorkers) as cluster:
         with Client(cluster) as client:
             parameters = {'tree_method': 'approx'}
             run_empty_dmatrix_reg(client, parameters)
@@ -397,7 +431,13 @@ async def run_dask_classifier_asyncio(scheduler_address):
         assert len(list(history['validation_0'])) == 1
         assert len(history['validation_0']['merror']) == 2
 
+        # Test .predict_proba()
+        probas = await classifier.predict_proba(X)
         assert classifier.n_classes_ == 10
+        assert probas.ndim == 2
+        assert probas.shape[0] == kRows
+        assert probas.shape[1] == 10
+
 
         # Test with dataframe.
         X_d = dd.from_dask_array(X)
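For reference, a hedged usage sketch of the classifier API this patch adds; the cluster size and data are arbitrary. With more than two classes the estimator trains under `multi:softprob`, so `predict_proba` yields a lazy `(n_samples, n_classes)` dask array while `predict` reduces it to labels via argmax:

```python
import xgboost as xgb
from dask import array as da
from dask.distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster:
    with Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.randint(0, 3, size=1000, chunks=100)

        clf = xgb.dask.DaskXGBClassifier(n_estimators=2)
        clf.client = client
        clf.fit(X, y)

        proba = clf.predict_proba(X)  # lazy (1000, 3) array of class probabilities
        labels = clf.predict(X)       # argmax over classes in the multi-class case

        assert proba.compute().shape == (1000, 3)
        assert labels.compute().shape == (1000,)
```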