[backport] Fix GPU L1 error. (#8749) (#8770)

* [backport] Fix GPU L1 error. (#8749)

* Fix backport.
Author:  Jiaming Yuan (committed by GitHub)
Date:    2023-02-09 20:16:39 +08:00
Parent:  df984f9c43
Commit:  60303db2ee
5 changed files with 74 additions and 13 deletions
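Background for the change: before this fix, GPUHistMakerDevice cached the leaf value of every training row in `update_predictions` while the tree was still being built, and `UpdatePredictionCache` simply added those cached values to the prediction cache. Objectives that adjust leaf values after tree construction, such as the L1 objective `reg:absoluteerror` (which refreshes each leaf from quantiles of its residuals), therefore fed stale leaf values into the cache. The patch instead records the final node index of each row (`positions`), looks the leaf value up from the finished tree inside `UpdatePredictionCache`, and adds `HasNodePosition()` so callers know the updater reports node positions. Below is a minimal stand-alone repro sketch mirroring the regression test added in this commit; the synthetic data, seed, and sizes are illustrative only, and a CUDA device is required for `gpu_hist`.

import numpy as np
import xgboost as xgb

# Illustrative synthetic regression data (not part of the patch).
rng = np.random.default_rng(1994)
X = rng.normal(size=(1024, 16))
y = rng.normal(size=1024) * 10.0
Xy = xgb.DMatrix(X, label=y)

evals_result = {}
xgb.train(
    {
        "tree_method": "gpu_hist",          # the updater fixed by this commit
        "objective": "reg:absoluteerror",   # L1 objective adjusts leaves after construction
        "subsample": 0.8,                   # row sampling, as in the added test
    },
    Xy,
    num_boost_round=10,
    evals=[(Xy, "Train")],
    evals_result=evals_result,
)

mae = evals_result["Train"]["mae"]
# Before the fix, the reported training MAE stagnated because the prediction
# cache was built from stale leaf values; with the fix it improves.  The test
# added below asserts the stricter property that MAE is non-increasing.
assert mae[-1] < mae[0]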

@@ -36,7 +36,6 @@ try:
     PANDAS_INSTALLED = True
 except ImportError:
     MultiIndex = object
     DataFrame = object
     Series = object
@@ -161,6 +160,7 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
 # `importlib.utils`, except it's unclear from its document on how to use it. This one
 # seems to be easy to understand and works out of box.
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this

@@ -188,7 +188,8 @@ struct GPUHistMakerDevice {
   common::Span<GradientPair> gpair;
   dh::device_vector<int> monotone_constraints;
-  dh::device_vector<float> update_predictions;
+  // node idx for each sample
+  dh::device_vector<bst_node_t> positions;
   TrainParam param;
@@ -426,7 +427,7 @@ struct GPUHistMakerDevice {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
     p_out_position->Resize(0);
-    update_predictions.clear();
+    positions.clear();
     return;
   }
@@ -461,8 +462,6 @@ struct GPUHistMakerDevice {
                               HostDeviceVector<bst_node_t>* p_out_position) {
     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
     auto d_gpair = this->gpair;
-    update_predictions.resize(row_partitioner->GetRows().size());
-    auto d_update_predictions = dh::ToSpan(update_predictions);
     p_out_position->SetDevice(ctx_->gpu_id);
     p_out_position->Resize(row_partitioner->GetRows().size());
@@ -497,32 +496,45 @@ struct GPUHistMakerDevice {
         node = d_nodes[position];
       }
-      d_update_predictions[row_id] = node.LeafValue();
       return position;
     };  // NOLINT
     auto d_out_position = p_out_position->DeviceSpan();
     row_partitioner->FinalisePosition(d_out_position, new_position_op);
+    auto s_position = p_out_position->ConstDeviceSpan();
+    positions.resize(s_position.size());
+    dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(),
+                                  s_position.size_bytes(), cudaMemcpyDeviceToDevice));
     dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) {
       bst_node_t position = d_out_position[idx];
-      d_update_predictions[idx] = d_nodes[position].LeafValue();
       bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f;
       d_out_position[idx] = is_row_sampled ? ~position : position;
     });
   }
   bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
-    if (update_predictions.empty()) {
+    if (positions.empty()) {
       return false;
     }
     CHECK(p_tree);
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
-    auto d_update_predictions = dh::ToSpan(update_predictions);
-    CHECK_EQ(out_preds_d.Size(), d_update_predictions.size());
-    dh::LaunchN(out_preds_d.Size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
-      out_preds_d(idx) += d_update_predictions[idx];
+    auto d_position = dh::ToSpan(positions);
+    CHECK_EQ(out_preds_d.Size(), d_position.size());
+    auto const& h_nodes = p_tree->GetNodes();
+    dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
+    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
+                                  h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice));
+    auto d_nodes = dh::ToSpan(nodes);
+    dh::LaunchN(d_position.size(), [=] XGBOOST_DEVICE(std::size_t idx) mutable {
+      bst_node_t nidx = d_position[idx];
+      auto weight = d_nodes[nidx].LeafValue();
+      out_preds_d(idx) += weight;
     });
     return true;
   }
@@ -865,6 +877,7 @@ class GPUHistMaker : public TreeUpdater {
   std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker;  // NOLINT
   char const* Name() const override { return "grow_gpu_hist"; }
+  bool HasNodePosition() const override { return true; }
  private:
   bool initialised_{false};
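A note on the design visible in the hunks above: rather than materialising per-row leaf values while finalising the row positions, the updater now copies the finalised row-to-node mapping into `positions` and defers the leaf-value lookup to `UpdatePredictionCache`, where the nodes of the finished tree are copied to the device and `LeafValue()` is read per row. Any change applied to the leaves between tree construction and the cache update, such as the post-hoc leaf adjustment performed for `reg:absoluteerror`, is therefore reflected in the cached predictions, at the cost of one extra host-to-device copy of the tree's nodes.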

@@ -0,0 +1,24 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/task.h>
+#include <xgboost/tree_updater.h>
+
+namespace xgboost {
+TEST(Updater, HasNodePosition) {
+  Context ctx;
+  ObjInfo task{ObjInfo::kRegression, true, true};
+  std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)};
+  ASSERT_TRUE(up->HasNodePosition());
+
+  up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task));
+  ASSERT_TRUE(up->HasNodePosition());
+
+#if defined(XGBOOST_USE_CUDA)
+  ctx.gpu_id = 0;
+  up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task));
+  ASSERT_TRUE(up->HasNodePosition());
+#endif  // defined(XGBOOST_USE_CUDA)
+}
+}  // namespace xgboost

@@ -338,13 +338,21 @@ class TestGPUPredict:
     @given(predict_parameter_strategy, tm.dataset_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
     def test_predict_leaf_gbtree(self, param, dataset):
+        # Unsupported for random forest
+        if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
+            return
+
         param['booster'] = 'gbtree'
         param['tree_method'] = 'gpu_hist'
         self.run_predict_leaf_booster(param, 10, dataset)

     @given(predict_parameter_strategy, tm.dataset_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
-    def test_predict_leaf_dart(self, param, dataset):
+    def test_predict_leaf_dart(self, param: dict, dataset: tm.TestDataset) -> None:
+        # Unsupported for random forest
+        if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
+            return
+
         param['booster'] = 'dart'
         param['tree_method'] = 'gpu_hist'
         self.run_predict_leaf_booster(param, 10, dataset)

@@ -458,6 +458,22 @@ class TestTreeMethod:
         config_0 = json.loads(booster_0.save_config())
         np.testing.assert_allclose(get_score(config_0), get_score(config_1) + 1)
+
+        evals_result: Dict[str, Dict[str, list]] = {}
+        xgb.train(
+            {
+                "tree_method": tree_method,
+                "objective": "reg:absoluteerror",
+                "subsample": 0.8
+            },
+            Xy,
+            num_boost_round=10,
+            evals=[(Xy, "Train")],
+            evals_result=evals_result,
+        )
+        mae = evals_result["Train"]["mae"]
+        assert mae[-1] < 20.0
+        assert tm.non_increasing(mae)

     @pytest.mark.skipif(**tm.no_sklearn())
     @pytest.mark.parametrize(
         "tree_method,weighted", [