* [backport] Fix GPU L1 error. (#8749) * Fix backport.
This commit is contained in:
parent
df984f9c43
commit
60303db2ee
@ -36,7 +36,6 @@ try:
|
||||
|
||||
PANDAS_INSTALLED = True
|
||||
except ImportError:
|
||||
|
||||
MultiIndex = object
|
||||
DataFrame = object
|
||||
Series = object
|
||||
@ -161,6 +160,7 @@ def concat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statem
|
||||
# `importlib.utils`, except it's unclear from its document on how to use it. This one
|
||||
# seems to be easy to understand and works out of box.
|
||||
|
||||
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
|
||||
|
||||
@ -188,7 +188,8 @@ struct GPUHistMakerDevice {
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::device_vector<int> monotone_constraints;
|
||||
dh::device_vector<float> update_predictions;
|
||||
// node idx for each sample
|
||||
dh::device_vector<bst_node_t> positions;
|
||||
|
||||
TrainParam param;
|
||||
|
||||
@ -426,7 +427,7 @@ struct GPUHistMakerDevice {
|
||||
LOG(FATAL) << "Current objective function can not be used with external memory.";
|
||||
}
|
||||
p_out_position->Resize(0);
|
||||
update_predictions.clear();
|
||||
positions.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -461,8 +462,6 @@ struct GPUHistMakerDevice {
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
|
||||
auto d_gpair = this->gpair;
|
||||
update_predictions.resize(row_partitioner->GetRows().size());
|
||||
auto d_update_predictions = dh::ToSpan(update_predictions);
|
||||
p_out_position->SetDevice(ctx_->gpu_id);
|
||||
p_out_position->Resize(row_partitioner->GetRows().size());
|
||||
|
||||
@ -497,32 +496,45 @@ struct GPUHistMakerDevice {
|
||||
node = d_nodes[position];
|
||||
}
|
||||
|
||||
d_update_predictions[row_id] = node.LeafValue();
|
||||
return position;
|
||||
}; // NOLINT
|
||||
|
||||
auto d_out_position = p_out_position->DeviceSpan();
|
||||
row_partitioner->FinalisePosition(d_out_position, new_position_op);
|
||||
|
||||
auto s_position = p_out_position->ConstDeviceSpan();
|
||||
positions.resize(s_position.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(),
|
||||
s_position.size_bytes(), cudaMemcpyDeviceToDevice));
|
||||
|
||||
dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) {
|
||||
bst_node_t position = d_out_position[idx];
|
||||
d_update_predictions[idx] = d_nodes[position].LeafValue();
|
||||
bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f;
|
||||
d_out_position[idx] = is_row_sampled ? ~position : position;
|
||||
});
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
|
||||
if (update_predictions.empty()) {
|
||||
if (positions.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CHECK(p_tree);
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
|
||||
auto d_update_predictions = dh::ToSpan(update_predictions);
|
||||
CHECK_EQ(out_preds_d.Size(), d_update_predictions.size());
|
||||
dh::LaunchN(out_preds_d.Size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
|
||||
out_preds_d(idx) += d_update_predictions[idx];
|
||||
|
||||
auto d_position = dh::ToSpan(positions);
|
||||
CHECK_EQ(out_preds_d.Size(), d_position.size());
|
||||
|
||||
auto const& h_nodes = p_tree->GetNodes();
|
||||
dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
|
||||
h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice));
|
||||
auto d_nodes = dh::ToSpan(nodes);
|
||||
dh::LaunchN(d_position.size(), [=] XGBOOST_DEVICE(std::size_t idx) mutable {
|
||||
bst_node_t nidx = d_position[idx];
|
||||
auto weight = d_nodes[nidx].LeafValue();
|
||||
out_preds_d(idx) += weight;
|
||||
});
|
||||
return true;
|
||||
}
|
||||
@ -865,6 +877,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
char const* Name() const override { return "grow_gpu_hist"; }
|
||||
bool HasNodePosition() const override { return true; }
|
||||
|
||||
private:
|
||||
bool initialised_{false};
|
||||
|
||||
24
tests/cpp/tree/test_node_partition.cc
Normal file
24
tests/cpp/tree/test_node_partition.cc
Normal file
@ -0,0 +1,24 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Updater, HasNodePosition) {
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression, true, true};
|
||||
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)};
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
|
||||
up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task));
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
ctx.gpu_id = 0;
|
||||
up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task));
|
||||
ASSERT_TRUE(up->HasNodePosition());
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -338,13 +338,21 @@ class TestGPUPredict:
|
||||
@given(predict_parameter_strategy, tm.dataset_strategy)
|
||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||
def test_predict_leaf_gbtree(self, param, dataset):
|
||||
# Unsupported for random forest
|
||||
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
|
||||
return
|
||||
|
||||
param['booster'] = 'gbtree'
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
self.run_predict_leaf_booster(param, 10, dataset)
|
||||
|
||||
@given(predict_parameter_strategy, tm.dataset_strategy)
|
||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||
def test_predict_leaf_dart(self, param, dataset):
|
||||
def test_predict_leaf_dart(self, param: dict, dataset: tm.TestDataset) -> None:
|
||||
# Unsupported for random forest
|
||||
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
|
||||
return
|
||||
|
||||
param['booster'] = 'dart'
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
self.run_predict_leaf_booster(param, 10, dataset)
|
||||
|
||||
@ -458,6 +458,22 @@ class TestTreeMethod:
|
||||
config_0 = json.loads(booster_0.save_config())
|
||||
np.testing.assert_allclose(get_score(config_0), get_score(config_1) + 1)
|
||||
|
||||
evals_result: Dict[str, Dict[str, list]] = {}
|
||||
xgb.train(
|
||||
{
|
||||
"tree_method": tree_method,
|
||||
"objective": "reg:absoluteerror",
|
||||
"subsample": 0.8
|
||||
},
|
||||
Xy,
|
||||
num_boost_round=10,
|
||||
evals=[(Xy, "Train")],
|
||||
evals_result=evals_result,
|
||||
)
|
||||
mae = evals_result["Train"]["mae"]
|
||||
assert mae[-1] < 20.0
|
||||
assert tm.non_increasing(mae)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
@pytest.mark.parametrize(
|
||||
"tree_method,weighted", [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user