[EM] Avoid synchronous calls and unnecessary ATS access. (#10811)
- Pass context into various functions. - Factor out some CUDA algorithms. - Use ATS only for update position.
This commit is contained in:
@@ -60,8 +60,7 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
|
||||
GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
|
||||
|
||||
evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false,
|
||||
ctx.Device());
|
||||
evaluator.Reset(&ctx, cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||
|
||||
ASSERT_EQ(result.thresh, 1);
|
||||
@@ -104,7 +103,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, ctx.Device());
|
||||
evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
|
||||
|
||||
{
|
||||
// -1.0s go right
|
||||
@@ -217,7 +216,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, ctx.Device());
|
||||
evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
|
||||
|
||||
{
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
@@ -277,10 +276,8 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
cuts.min_vals_.ConstDeviceSpan(),
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()),
|
||||
ctx.Device()};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false,
|
||||
ctx.Device());
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
|
||||
evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
|
||||
|
||||
{
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
@@ -336,10 +333,8 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
cuts.min_vals_.ConstDeviceSpan(),
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()),
|
||||
ctx.Device()};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false,
|
||||
ctx.Device());
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
|
||||
evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
@@ -522,7 +517,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
|
||||
cuts_.cut_values_.SetDevice(ctx.Device());
|
||||
cuts_.min_vals_.SetDevice(ctx.Device());
|
||||
|
||||
evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, false, ctx.Device());
|
||||
evaluator.Reset(&ctx, cuts_, dh::ToSpan(ft), info_.num_col_, param_, false);
|
||||
|
||||
// Convert the sample histogram to fixed point
|
||||
auto quantiser = DummyRoundingFactor(&ctx);
|
||||
@@ -586,7 +581,7 @@ void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) {
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true, ctx.Device());
|
||||
evaluator.Reset(&ctx, cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
|
||||
Reference in New Issue
Block a user