[EM] Support quantile objectives for GPU-based external memory. (#10820)
- Improved error message for memory usage.
- Support quantile-based objectives for GPU external memory.
This commit is contained in:
parent de00e07087
commit 96bbf80457
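As context for the change, here is a minimal sketch (not part of this commit) of the workflow it enables: training with the quantile objective on CUDA while streaming batches through the external-memory DataIter interface. The iterator class, batch shapes, cache prefix, and hyperparameters below are illustrative assumptions rather than code from this diff.

import numpy as np
import xgboost as xgb


class Batches(xgb.DataIter):
    """Serve in-memory batches one at a time to emulate an external-memory source."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__(cache_prefix="cache")

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


rng = np.random.default_rng(0)
batches = [(rng.random((256, 16)), rng.random(256)) for _ in range(4)]
# Build an external-memory DMatrix from the iterator; batches are cached on disk.
Xy = xgb.DMatrix(Batches(batches))
booster = xgb.train(
    {
        "device": "cuda",
        "tree_method": "hist",
        # Quantile loss adjusts leaf values after tree construction; this commit
        # lifts the restriction that rejected such objectives for GPU external memory.
        "objective": "reg:quantileerror",
        "quantile_alpha": 0.5,
    },
    Xy,
    num_boost_round=8,
)

Before this commit, the same setup was rejected on CUDA with "Current objective doesn't support external memory." (see the check removed from GBTree::DoBoost below).
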
@@ -5,9 +5,11 @@

#include <dmlc/thread_local.h>  // for ThreadLocalStore

#include <cmath>    // for pow
#include <cstdint>  // for uint8_t
#include <cstdio>   // for snprintf, size_t
#include <string>   // for string
#include <utility>  // for pair

#include "./random.h"  // for GlobalRandomEngine, GlobalRandom

@@ -54,4 +56,20 @@ void EscapeU8(std::string const &string, std::string *p_buffer) {
    }
  }
}

std::string HumanMemUnit(std::size_t n_bytes) {
  auto n_bytes_f64 = static_cast<double>(n_bytes);
  double constexpr k1024 = 1024.0;
  using P = std::pair<std::int32_t, StringView>;
  std::stringstream ss;
  for (auto pu : {P{3, "GB"}, P{2, "MB"}, P{1, "KB"}}) {
    auto const [power, unit] = pu;  // NOLINT
    if (n_bytes_f64 >= (std::pow(k1024, power))) {
      ss << (n_bytes_f64 / std::pow(k1024, power)) << unit;
      return ss.str();
    }
  }
  ss << n_bytes_f64 << "B";
  return ss.str();
}
}  // namespace xgboost::common
@@ -188,5 +188,8 @@ template <typename Indexable>
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
  return indptr[group + 1] - 1;
}

// Convert the number of bytes to a human readable unit.
std::string HumanMemUnit(std::size_t n_bytes);
}  // namespace xgboost::common
#endif  // XGBOOST_COMMON_COMMON_H_
@@ -15,8 +15,7 @@
#include <algorithm>
#include <cstddef>  // for size_t
#include <cub/cub.cuh>
#include <cub/util_type.cuh>  // for UnitWord
#include <tuple>
#include <cub/util_type.cuh>  // for UnitWord, DoubleBuffer
#include <vector>

#include "common.h"
@@ -635,7 +634,7 @@ size_t SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy
    return thrust::make_pair(seg, *(val_first + i));
  });
  size_t segments_len = key_segments_last - key_segments_first;
  thrust::fill(thrust::device, key_segments_out, key_segments_out + segments_len, 0);
  thrust::fill(exec, key_segments_out, key_segments_out + segments_len, 0);
  size_t n_inputs = std::distance(val_first, val_last);
  // Reduce the number of uniques elements per segment, avoid creating an intermediate
  // array for `reduce_by_key`. It's limited by the types that atomicAdd supports. For
@@ -736,22 +735,32 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
class CUDAStreamView;

class CUDAEvent {
  cudaEvent_t event_{nullptr};
  std::unique_ptr<cudaEvent_t, void (*)(cudaEvent_t *)> event_;

 public:
  CUDAEvent() { dh::safe_cuda(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); }
  ~CUDAEvent() {
    if (event_) {
      dh::safe_cuda(cudaEventDestroy(event_));
    }
  CUDAEvent()
      : event_{[] {
                 auto e = new cudaEvent_t;
                 dh::safe_cuda(cudaEventCreateWithFlags(e, cudaEventDisableTiming));
                 return e;
               }(),
               [](cudaEvent_t *e) {
                 if (e) {
                   dh::safe_cuda(cudaEventDestroy(*e));
                   delete e;
                 }
               }} {}

  inline void Record(CUDAStreamView stream);  // NOLINT
  // Define swap-based ctor to make sure an event is always valid.
  CUDAEvent(CUDAEvent &&e) : CUDAEvent() { std::swap(this->event_, e.event_); }
  CUDAEvent &operator=(CUDAEvent &&e) {
    std::swap(this->event_, e.event_);
    return *this;
  }

  CUDAEvent(CUDAEvent const &that) = delete;
  CUDAEvent &operator=(CUDAEvent const &that) = delete;

  inline void Record(CUDAStreamView stream);  // NOLINT

  operator cudaEvent_t() const { return event_; }  // NOLINT
  operator cudaEvent_t() const { return *event_; }  // NOLINT
  cudaEvent_t const *data() const { return this->event_.get(); }  // NOLINT
};

class CUDAStreamView {
@@ -785,7 +794,7 @@ class CUDAStreamView {
};

inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
  dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
  dh::safe_cuda(cudaEventRecord(*event_, cudaStream_t{stream}));
}

// Changing this has effect on prediction return, where we need to pass the pointer to
@@ -2,18 +2,20 @@
 * Copyright 2017-2024, XGBoost contributors
 */
#include "../collective/communicator-inl.h"  // for GetRank
#include "common.h"                          // for HumanMemUnit
#include "device_helpers.cuh"                // for CurrentDevice
#include "device_vector.cuh"

namespace dh {
namespace detail {
void ThrowOOMError(std::string const &err, size_t bytes) {
void ThrowOOMError(std::string const &err, std::size_t bytes) {
  auto device = CurrentDevice();
  auto rank = xgboost::collective::GetRank();
  using xgboost::common::HumanMemUnit;
  std::stringstream ss;
  ss << "Memory allocation error on worker " << rank << ": " << err << "\n"
     << "- Free memory: " << dh::AvailableMemory(device) << "\n"
     << "- Requested memory: " << bytes << std::endl;
     << "- Free memory: " << HumanMemUnit(dh::AvailableMemory(device)) << "\n"
     << "- Requested memory: " << HumanMemUnit(bytes) << std::endl;
  LOG(FATAL) << ss.str();
}
}  // namespace detail
@@ -31,7 +31,7 @@
#include <map>     // for map
#include <memory>  // for unique_ptr

#include "common.h"  // for safe_cuda
#include "common.h"  // for safe_cuda, HumanMemUnit
#include "xgboost/logging.h"

namespace dh {
@@ -97,12 +97,13 @@ class MemoryLogger {
    dh::safe_cuda(cudaGetDevice(&current_device));
    LOG(CONSOLE) << "======== Device " << current_device << " Memory Allocations: "
                 << " ========";
    LOG(CONSOLE) << "Peak memory usage: " << stats_.peak_allocated_bytes / 1048576 << "MiB";
    LOG(CONSOLE) << "Peak memory usage: "
                 << xgboost::common::HumanMemUnit(stats_.peak_allocated_bytes);
    LOG(CONSOLE) << "Number of allocations: " << stats_.num_allocations;
  }
};

void ThrowOOMError(std::string const &err, size_t bytes);
void ThrowOOMError(std::string const &err, std::size_t bytes);
}  // namespace detail

inline detail::MemoryLogger &GlobalMemoryLogger() {
@@ -218,10 +218,6 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
                                model_.learner_model_param->OutputLength());
  CHECK_NE(n_groups, 0);

  if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf() && this->ctx_->IsCUDA()) {
    LOG(FATAL) << "Current objective doesn't support external memory.";
  }

  // The node position for each row, 1 HDV for each tree in the forest. Note that the
  // position is negated if the row is sampled out.
  std::vector<HostDeviceVector<bst_node_t>> node_position;
@@ -148,9 +148,10 @@ class CommonRowPartitioner {
  template <typename ExpandEntry, typename GHistIndexMatrixT>
  static void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
                                  GHistIndexMatrixT const& gmat,
                                  std::vector<int32_t>* split_conditions) {
                                  std::vector<int32_t>* p_split_conditions) {
    auto const& ptrs = gmat.cut.Ptrs();
    auto const& vals = gmat.cut.Values();
    auto& split_conditions = *p_split_conditions;

    for (std::size_t i = 0; i < nodes.size(); ++i) {
      bst_node_t const nidx = nodes[i].nid;
@@ -167,7 +168,7 @@ class CommonRowPartitioner {
          split_cond = static_cast<bst_bin_t>(bound);
        }
      }
      (*split_conditions)[i] = split_cond;
      split_conditions[i] = split_cond;
    }
  }

@@ -520,12 +520,11 @@ struct GPUHistMakerDevice {
  // prediction cache
  void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task,
                        HostDeviceVector<bst_node_t>* p_out_position) {
    if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
      LOG(FATAL) << "Current objective function can not be used with external memory.";
    }

    monitor.Start(__func__);
    if (static_cast<std::size_t>(p_fmat->NumBatches() + 1) != this->batch_ptr_.size()) {
      if (task.UpdateTreeLeaf()) {
        LOG(FATAL) << "Current objective function can not be used with concatenated pages.";
      }
      // External memory with concatenation. Not supported.
      p_out_position->Resize(0);
      positions_.clear();
@@ -3,14 +3,21 @@
 */
#include <gtest/gtest.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>    // for sort
#include <thrust/unique.h>  // for unique
#include <xgboost/base.h>
#include <xgboost/tree_model.h>  // for RegTree

#include <cstddef>  // for size_t
#include <cstdint>  // for uint32_t
#include <vector>   // for vector
#include <cstddef>   // for size_t
#include <cstdint>   // for uint32_t
#include <iterator>  // for distance
#include <vector>    // for vector

#include "../../../../src/data/ellpack_page.cuh"
#include "../../../../src/tree/gpu_hist/expand_entry.cuh"  // for GPUExpandEntry
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../helpers.h"
#include "xgboost/base.h"
#include "../../../../src/tree/param.h"  // for TrainParam
#include "../../helpers.h"               // for RandomDataGenerator

namespace xgboost::tree {
void TestUpdatePositionBatch() {
@@ -91,4 +98,83 @@ TEST(GpuHist, SortPositionBatch) {
  TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 6}});
  TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{3, 6}, {0, 2}});
}

namespace {
void GetSplit(RegTree* tree, float split_value, std::vector<GPUExpandEntry>* candidates) {
  CHECK(!tree->IsMultiTarget());
  tree->ExpandNode(
      /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
      /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      /*left_sum=*/0.0f,
      /*right_sum=*/0.0f);
  candidates->front().nid = 0;
  candidates->front().depth = 0;
  candidates->front().split.fvalue = split_value;
  candidates->front().split.findex = 0;
}

void TestExternalMemory() {
  auto ctx = MakeCUDACtx(0);

  bst_bin_t max_bin = 32;
  auto p_fmat =
      RandomDataGenerator{256, 16, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);

  std::vector<std::unique_ptr<RowPartitioner>> partitioners;
  RegTree tree;
  std::vector<GPUExpandEntry> candidates(1);

  auto param = BatchParam{max_bin, TrainParam::DftSparseThreshold()};
  float split_value{0.0f};
  bst_feature_t const split_ind = 0;
  dh::device_vector<bst_node_t> position(p_fmat->Info().num_row_, 0);

  auto encode_op = [=] __device__(bst_idx_t, bst_node_t nidx) {
    return nidx;
  };  // NOLINT

  for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
    if (partitioners.empty()) {
      auto ptr = page.Impl()->Cuts().Ptrs()[split_ind + 1];
      split_value = page.Impl()->Cuts().Values().at(ptr / 2);
      GetSplit(&tree, split_value, &candidates);
    }

    partitioners.emplace_back(std::make_unique<RowPartitioner>());
    partitioners.back()->Reset(&ctx, page.Size(), page.BaseRowId());
    std::vector<RegTree::Node> splits{tree[0]};
    auto acc = page.Impl()->GetDeviceAccessor(&ctx);
    partitioners.back()->UpdatePositionBatch(
        {0}, {1}, {2}, splits,
        [=] __device__(bst_idx_t ridx, std::int32_t nidx_in_batch, RegTree::Node const& node) {
          auto fvalue = acc.GetFvalue(ridx, node.SplitIndex());
          return fvalue <= node.SplitCond();
        });
    partitioners.back()->FinalisePosition(
        &ctx, dh::ToSpan(position).subspan(page.BaseRowId(), page.Size()), page.BaseRowId(),
        encode_op);
  }

  bst_idx_t n_left{0};
  for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
    auto batch = page.GetView();
    for (size_t i = 0; i < batch.Size(); ++i) {
      if (batch[i][split_ind].fvalue < split_value) {
        n_left++;
      }
    }
  }

  RegTree::Node node = tree[RegTree::kRoot];
  auto n_left_pos =
      thrust::count_if(position.cbegin(), position.cend(),
                       [=] XGBOOST_DEVICE(bst_node_t v) { return v == node.LeftChild(); });
  ASSERT_EQ(n_left, n_left_pos);
  thrust::sort(position.begin(), position.end());
  auto end_it = thrust::unique(position.begin(), position.end());
  ASSERT_EQ(std::distance(position.begin(), end_it), 2);
}
}  // anonymous namespace

TEST(RowPartitioner, LeafPartitionExternalMemory) { TestExternalMemory(); }
}  // namespace xgboost::tree
@@ -70,3 +70,28 @@ def test_extmem_qdm(
    n_samples_per_batch: int, n_features: int, n_batches: int, on_host: bool
) -> None:
    check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host)


@given(
    strategies.integers(1, 64),
    strategies.integers(1, 8),
    strategies.integers(1, 4),
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_quantile_objective(
    n_samples_per_batch: int, n_features: int, n_batches: int
) -> None:
    check_quantile_loss_extmem(
        n_samples_per_batch,
        n_features,
        n_batches,
        "hist",
        "cuda",
    )
    check_quantile_loss_extmem(
        n_samples_per_batch,
        n_features,
        n_batches,
        "approx",
        "cuda",
    )