Fix GPU L1 error. (#8749)
commit 0e61ba57d6 (parent 16ef016ba7)
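Editor's note on the change: for objectives such as reg:absoluteerror, leaf values are recomputed from per-row node positions after the tree is built, so caching each row's leaf value (update_predictions) at FinalisePosition time left the GPU prediction cache stale. The diff below switches the cache to per-row node indices (positions) and resolves leaf values against the finished tree when the cache is updated. A minimal sketch of the user-visible scenario, mirroring the regression test added in this commit (data and parameters are illustrative only, and gpu_hist requires a CUDA build):

    import numpy as np
    import xgboost as xgb

    # Illustrative synthetic regression data; any dataset works.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(1024, 16))
    y = X[:, 0] * 3.0 + rng.normal(size=1024)
    Xy = xgb.DMatrix(X, label=y)

    evals_result = {}
    xgb.train(
        {
            # L1 objective: leaves are refreshed after construction,
            # which is the code path this commit fixes on GPU.
            "objective": "reg:absoluteerror",
            "tree_method": "gpu_hist",
            "subsample": 0.8,
        },
        Xy,
        num_boost_round=10,
        evals=[(Xy, "Train")],
        evals_result=evals_result,
    )
    # Before the fix the cached predictions could disagree with the
    # adjusted leaves; after it, training MAE should keep improving.
    mae = evals_result["Train"]["mae"]
    assert mae[-1] < mae[0]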
@@ -36,7 +36,6 @@ try:
 
     PANDAS_INSTALLED = True
 except ImportError:
-
    MultiIndex = object
    DataFrame = object
    Series = object
@@ -161,6 +160,7 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
 # `importlib.utils`, except it's unclear from its document on how to use it. This one
 # seems to be easy to understand and works out of box.
 
+
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
@@ -188,7 +188,8 @@ struct GPUHistMakerDevice {
   common::Span<GradientPair> gpair;
 
   dh::device_vector<int> monotone_constraints;
-  dh::device_vector<float> update_predictions;
+  // node idx for each sample
+  dh::device_vector<bst_node_t> positions;
 
   TrainParam param;
 
@@ -423,7 +424,7 @@ struct GPUHistMakerDevice {
         LOG(FATAL) << "Current objective function can not be used with external memory.";
       }
       p_out_position->Resize(0);
-      update_predictions.clear();
+      positions.clear();
       return;
     }
 
@@ -458,8 +459,6 @@ struct GPUHistMakerDevice {
                               HostDeviceVector<bst_node_t>* p_out_position) {
     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
     auto d_gpair = this->gpair;
-    update_predictions.resize(row_partitioner->GetRows().size());
-    auto d_update_predictions = dh::ToSpan(update_predictions);
     p_out_position->SetDevice(ctx_->gpu_id);
     p_out_position->Resize(row_partitioner->GetRows().size());
 
@@ -494,32 +493,48 @@ struct GPUHistMakerDevice {
         node = d_nodes[position];
       }
 
-      d_update_predictions[row_id] = node.LeafValue();
       return position;
     };  // NOLINT
 
     auto d_out_position = p_out_position->DeviceSpan();
     row_partitioner->FinalisePosition(d_out_position, new_position_op);
 
+    auto s_position = p_out_position->ConstDeviceSpan();
+    positions.resize(s_position.size());
+    dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(),
+                                  s_position.size_bytes(), cudaMemcpyDeviceToDevice,
+                                  ctx_->CUDACtx()->Stream()));
+
     dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) {
       bst_node_t position = d_out_position[idx];
-      d_update_predictions[idx] = d_nodes[position].LeafValue();
       bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f;
       d_out_position[idx] = is_row_sampled ? ~position : position;
     });
   }
 
   bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
-    if (update_predictions.empty()) {
+    if (positions.empty()) {
       return false;
     }
 
     CHECK(p_tree);
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
-    auto d_update_predictions = dh::ToSpan(update_predictions);
-    CHECK_EQ(out_preds_d.Size(), d_update_predictions.size());
-    dh::LaunchN(out_preds_d.Size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
-      out_preds_d(idx) += d_update_predictions[idx];
+    auto d_position = dh::ToSpan(positions);
+    CHECK_EQ(out_preds_d.Size(), d_position.size());
+
+    auto const& h_nodes = p_tree->GetNodes();
+    dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
+    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
+                                  h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice,
+                                  ctx_->CUDACtx()->Stream()));
+    auto d_nodes = dh::ToSpan(nodes);
+    dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(),
+                [=] XGBOOST_DEVICE(std::size_t idx) mutable {
+                  bst_node_t nidx = d_position[idx];
+                  auto weight = d_nodes[nidx].LeafValue();
+                  out_preds_d(idx) += weight;
     });
     return true;
   }
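Editor's note: the heart of the fix is visible in the hunk above. FinalisePosition no longer materialises each row's leaf value (which would go stale once the objective rewrites the leaves for L1); it snapshots each row's final node index into positions, and UpdatePredictionCache gathers leaf values from the finished tree. A rough NumPy sketch of the two phases (names here are illustrative, not the library's API):

    import numpy as np

    # Phase 1 -- FinalisePosition: remember only where each row landed.
    positions = np.array([1, 2, 2, 1, 2])      # leaf index per row

    # ...the objective may now adjust the leaves (e.g. L1 re-fits each
    # leaf to its rows' residuals), producing the final values...
    leaf_value = np.array([0.0, 0.25, -0.40])  # indexed by node id

    # Phase 2 -- UpdatePredictionCache: resolve leaf values at lookup
    # time, so the cache always sees the post-adjustment leaves.
    out_preds = np.zeros(positions.size)
    out_preds += leaf_value[positions]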
@@ -862,6 +877,7 @@ class GPUHistMaker : public TreeUpdater {
   std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker;  // NOLINT
 
   char const* Name() const override { return "grow_gpu_hist"; }
+  bool HasNodePosition() const override { return true; }
 
  private:
   bool initialised_{false};
tests/cpp/tree/test_node_partition.cc (new file, 24 lines)
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/task.h>
+#include <xgboost/tree_updater.h>
+
+namespace xgboost {
+TEST(Updater, HasNodePosition) {
+  Context ctx;
+  ObjInfo task{ObjInfo::kRegression, true, true};
+  std::unique_ptr<TreeUpdater> up{TreeUpdater::Create("grow_histmaker", &ctx, task)};
+  ASSERT_TRUE(up->HasNodePosition());
+
+  up.reset(TreeUpdater::Create("grow_quantile_histmaker", &ctx, task));
+  ASSERT_TRUE(up->HasNodePosition());
+
+#if defined(XGBOOST_USE_CUDA)
+  ctx.gpu_id = 0;
+  up.reset(TreeUpdater::Create("grow_gpu_hist", &ctx, task));
+  ASSERT_TRUE(up->HasNodePosition());
+#endif  // defined(XGBOOST_USE_CUDA)
+}
+}  // namespace xgboost
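Editor's note on the new test: HasNodePosition() advertises that an updater can report each training row's final node, and ObjInfo{ObjInfo::kRegression, true, true} marks a task that updates tree leaves after construction. That per-row position is exactly what an L1-style leaf refresh consumes. Conceptually, the refresh looks like this sketch (hypothetical helper names, not XGBoost's internal API):

    import numpy as np

    def refresh_leaves_l1(positions, y, y_pred):
        """Sketch of an L1 leaf refresh: set each leaf to the median of
        the residuals of the rows that landed in it -- only possible when
        the updater exposes per-row node positions."""
        new_leaf = {}
        for nidx in np.unique(positions):
            rows = positions == nidx
            new_leaf[int(nidx)] = float(np.median(y[rows] - y_pred[rows]))
        return new_leaf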
@@ -337,13 +337,21 @@ class TestGPUPredict:
     @given(predict_parameter_strategy, tm.dataset_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
     def test_predict_leaf_gbtree(self, param, dataset):
+        # Unsupported for random forest
+        if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
+            return
+
         param['booster'] = 'gbtree'
         param['tree_method'] = 'gpu_hist'
         self.run_predict_leaf_booster(param, 10, dataset)
 
     @given(predict_parameter_strategy, tm.dataset_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
-    def test_predict_leaf_dart(self, param, dataset):
+    def test_predict_leaf_dart(self, param: dict, dataset: tm.TestDataset) -> None:
+        # Unsupported for random forest
+        if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
+            return
+
         param['booster'] = 'dart'
         param['tree_method'] = 'gpu_hist'
         self.run_predict_leaf_booster(param, 10, dataset)
@@ -442,6 +442,22 @@ class TestTreeMethod:
         config_0 = json.loads(booster_0.save_config())
         np.testing.assert_allclose(get_score(config_0), get_score(config_1) + 1)
 
+        evals_result: Dict[str, Dict[str, list]] = {}
+        xgb.train(
+            {
+                "tree_method": tree_method,
+                "objective": "reg:absoluteerror",
+                "subsample": 0.8
+            },
+            Xy,
+            num_boost_round=10,
+            evals=[(Xy, "Train")],
+            evals_result=evals_result,
+        )
+        mae = evals_result["Train"]["mae"]
+        assert mae[-1] < 20.0
+        assert tm.non_increasing(mae)
+
     @pytest.mark.skipif(**tm.no_sklearn())
     @pytest.mark.parametrize(
         "tree_method,weighted", [
@@ -215,7 +215,6 @@ MultiClfData = namedtuple("MultiClfData", ("multi_clf_df_train", "multi_clf_df_t
 
 @pytest.fixture
 def multi_clf_data(spark: SparkSession) -> Generator[MultiClfData, None, None]:
-
     X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]])
     y = np.array([0, 0, 1, 2])
     cls1 = xgb.XGBClassifier()