From 0e61ba57d629319e3e2820b8ccbc8ec2174fcc0b Mon Sep 17 00:00:00 2001
From: Jiaming Yuan <jm.yuan@outlook.com>
Date: Sat, 4 Feb 2023 03:02:00 +0800
Subject: [PATCH] Fix GPU L1 error. (#8749)

---
 python-package/xgboost/compat.py              |  2 +-
 src/tree/updater_gpu_hist.cu                  | 40 +++++++++++++------
 tests/cpp/tree/test_node_partition.cc         | 24 +++++++++++
 tests/python-gpu/test_gpu_prediction.py       | 10 ++++-
 tests/python/test_updaters.py                 | 16 ++++++++
 .../test_with_spark/test_spark_local.py       |  1 -
 6 files changed, 78 insertions(+), 15 deletions(-)
 create mode 100644 tests/cpp/tree/test_node_partition.cc

diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index fab734a01..3be023abf 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -36,7 +36,6 @@ try:
     from pandas import DataFrame, Series
 
     PANDAS_INSTALLED = True
 except ImportError:
-    MultiIndex = object
     DataFrame = object
     Series = object
@@ -161,6 +160,7 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
 # `importlib.utils`, except it's unclear from its document on how to use it. This one
 # seems to be easy to understand and works out of box.
 
+
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
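The compat.py change above just drops the now-unused `MultiIndex = object` fallback. For readers unfamiliar with the module, below is a minimal sketch (simplified, not the actual compat.py source) of the guarded-import pattern it relies on; the `is_pandas_df` helper is hypothetical:

# Optional pandas symbols fall back to plain `object` so the names stay
# defined when pandas is missing; call sites must consult PANDAS_INSTALLED
# before trusting isinstance() against the placeholders.
try:
    from pandas import DataFrame, Series

    PANDAS_INSTALLED = True
except ImportError:
    DataFrame = object
    Series = object
    PANDAS_INSTALLED = False


def is_pandas_df(data: object) -> bool:
    # Guard first: with the fallback in place, isinstance(data, object) is
    # True for every value, so isinstance alone would be meaningless.
    return PANDAS_INSTALLED and isinstance(data, DataFrame)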
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 853716726..87d50699a 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -188,7 +188,8 @@ struct GPUHistMakerDevice {
   common::Span<GradientPair> gpair;
 
   dh::device_vector<int> monotone_constraints;
-  dh::device_vector<float> update_predictions;
+  // node idx for each sample
+  dh::device_vector<bst_node_t> positions;
 
   TrainParam param;
 
@@ -423,7 +424,7 @@ struct GPUHistMakerDevice {
         LOG(FATAL) << "Current objective function can not be used with external memory.";
       }
       p_out_position->Resize(0);
-      update_predictions.clear();
+      positions.clear();
       return;
     }
 
@@ -458,8 +459,6 @@ struct GPUHistMakerDevice {
                               HostDeviceVector<bst_node_t>* p_out_position) {
     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
     auto d_gpair = this->gpair;
-    update_predictions.resize(row_partitioner->GetRows().size());
-    auto d_update_predictions = dh::ToSpan(update_predictions);
     p_out_position->SetDevice(ctx_->gpu_id);
     p_out_position->Resize(row_partitioner->GetRows().size());
 
@@ -494,33 +493,49 @@ struct GPUHistMakerDevice {
         node = d_nodes[position];
       }
 
-      d_update_predictions[row_id] = node.LeafValue();
       return position;
     };  // NOLINT
 
     auto d_out_position = p_out_position->DeviceSpan();
     row_partitioner->FinalisePosition(d_out_position, new_position_op);
 
+    auto s_position = p_out_position->ConstDeviceSpan();
+    positions.resize(s_position.size());
+    dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(),
+                                  s_position.size_bytes(), cudaMemcpyDeviceToDevice,
+                                  ctx_->CUDACtx()->Stream()));
+
     dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) {
       bst_node_t position = d_out_position[idx];
-      d_update_predictions[idx] = d_nodes[position].LeafValue();
       bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f;
       d_out_position[idx] = is_row_sampled ? ~position : position;
     });
   }
 
   bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
-    if (update_predictions.empty()) {
+    if (positions.empty()) {
       return false;
     }
+    CHECK(p_tree);
 
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
-    auto d_update_predictions = dh::ToSpan(update_predictions);
-    CHECK_EQ(out_preds_d.Size(), d_update_predictions.size());
-    dh::LaunchN(out_preds_d.Size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
-      out_preds_d(idx) += d_update_predictions[idx];
-    });
+
+    auto d_position = dh::ToSpan(positions);
+    CHECK_EQ(out_preds_d.Size(), d_position.size());
+
+    auto const& h_nodes = p_tree->GetNodes();
+    dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
+    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
+                                  h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice,
+                                  ctx_->CUDACtx()->Stream()));
+    auto d_nodes = dh::ToSpan(nodes);
+    dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),
+                [=] XGBOOST_DEVICE(std::size_t idx) mutable {
+                  bst_node_t nidx = d_position[idx];
+                  auto weight = d_nodes[nidx].LeafValue();
+                  out_preds_d(idx) += weight;
+                });
     return true;
   }
 
@@ -862,6 +877,7 @@ class GPUHistMaker : public TreeUpdater {
   std::unique_ptr<GPUHistMakerDevice> maker;  // NOLINT
 
   char const* Name() const override { return "grow_gpu_hist"; }
+  bool HasNodePosition() const override { return true; }
 
  private:
   bool initialised_{false};
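The two kernels above are the heart of the fix: `FinalisePositionInPage` no longer snapshots each row's leaf value during tree construction; instead it copies every row's finalised node index into `positions` (taken before sampled-out rows are bit-flipped with `~position` for the objective), and `UpdatePredictionCache` later gathers leaf values from the finished tree. That ordering is what makes `reg:absoluteerror` work: the L1 objective adjusts leaf values after the tree is built, so values cached during construction would be stale. A rough numpy analogue of the device-side gather (illustrative values only, not XGBoost code):

import numpy as np

# Hypothetical per-row node indices after FinalisePosition, and per-node leaf
# values of the finished tree (including any post-construction adjustment).
positions = np.array([3, 4, 4, 5, 6, 3])
leaf_value = np.array([0.0, 0.0, 0.0, -0.5, 0.25, 0.1, -0.2], dtype=np.float32)

out_preds = np.zeros(positions.shape[0], dtype=np.float32)

# Equivalent of the LaunchN kernel in UpdatePredictionCache:
#   out_preds_d(idx) += d_nodes[d_position[idx]].LeafValue();
out_preds += leaf_value[positions]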
diff --git a/tests/cpp/tree/test_node_partition.cc b/tests/cpp/tree/test_node_partition.cc
new file mode 100644
index 000000000..883c8e68f
--- /dev/null
+++ b/tests/cpp/tree/test_node_partition.cc
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/context.h>
+#include <xgboost/tree_updater.h>
+
+namespace xgboost {
+TEST(Updater, HasNodePosition) {
+  Context ctx;
+  ObjInfo task{ObjInfo::kRegression, true, true};
+  std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)};
+  ASSERT_TRUE(up->HasNodePosition());
+
+  up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task));
+  ASSERT_TRUE(up->HasNodePosition());
+
+#if defined(XGBOOST_USE_CUDA)
+  ctx.gpu_id = 0;
+  up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task));
+  ASSERT_TRUE(up->HasNodePosition());
+#endif  // defined(XGBOOST_USE_CUDA)
+}
+}  // namespace xgboost
diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py
index 63154e775..3f8b4557f 100644
--- a/tests/python-gpu/test_gpu_prediction.py
+++ b/tests/python-gpu/test_gpu_prediction.py
@@ -337,13 +337,21 @@ class TestGPUPredict:
     @given(predict_parameter_strategy, tm.dataset_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
     def test_predict_leaf_gbtree(self, param, dataset):
+        # Unsupported for random forest
+        if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
+            return
+
         param['booster'] = 'gbtree'
         param['tree_method'] = 'gpu_hist'
         self.run_predict_leaf_booster(param, 10, dataset)
 
     @given(predict_parameter_strategy, tm.dataset_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
-    def test_predict_leaf_dart(self, param, dataset):
+    def test_predict_leaf_dart(self, param: dict, dataset: tm.TestDataset) -> None:
+        # Unsupported for random forest
+        if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
+            return
+
         param['booster'] = 'dart'
         param['tree_method'] = 'gpu_hist'
         self.run_predict_leaf_booster(param, 10, dataset)
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index 98a58186e..130af619c 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -442,6 +442,22 @@ class TestTreeMethod:
         config_0 = json.loads(booster_0.save_config())
         np.testing.assert_allclose(get_score(config_0), get_score(config_1) + 1)
 
+        evals_result: Dict[str, Dict[str, list]] = {}
+        xgb.train(
+            {
+                "tree_method": tree_method,
+                "objective": "reg:absoluteerror",
+                "subsample": 0.8
+            },
+            Xy,
+            num_boost_round=10,
+            evals=[(Xy, "Train")],
+            evals_result=evals_result,
+        )
+        mae = evals_result["Train"]["mae"]
+        assert mae[-1] < 20.0
+        assert tm.non_increasing(mae)
+
     @pytest.mark.skipif(**tm.no_sklearn())
     @pytest.mark.parametrize(
         "tree_method,weighted", [
diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py
index 52b7fadb7..27f1ef06f 100644
--- a/tests/test_distributed/test_with_spark/test_spark_local.py
+++ b/tests/test_distributed/test_with_spark/test_spark_local.py
@@ -215,7 +215,6 @@ MultiClfData = namedtuple("MultiClfData", ("multi_clf_df_train", "multi_clf_df_test"))
 
 @pytest.fixture
 def multi_clf_data(spark: SparkSession) -> Generator[MultiClfData, None, None]:
-
     X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]])
     y = np.array([0, 0, 1, 2])
     cls1 = xgb.XGBClassifier()
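The behaviour pinned down by the new `test_updaters.py` assertions can be reproduced standalone. A sketch with synthetic data (assumes a CUDA-enabled build for `gpu_hist`; swap in `hist` for a CPU-only build):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(2023)
X = rng.normal(size=(1024, 8))
y = 3.0 * X[:, 0] + rng.normal(scale=0.5, size=1024)
Xy = xgb.DMatrix(X, label=y)

evals_result: dict = {}
xgb.train(
    {"tree_method": "gpu_hist", "objective": "reg:absoluteerror", "subsample": 0.8},
    Xy,
    num_boost_round=10,
    evals=[(Xy, "Train")],
    evals_result=evals_result,
)

# reg:absoluteerror reports MAE by default; with the prediction cache fixed,
# training error should not increase between rounds on this easy problem.
mae = evals_result["Train"]["mae"]
assert mae == sorted(mae, reverse=True)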