Support optimal partitioning for GPU hist. (#7652)

* Implement `MaxCategory` in quantile.
* Implement partition-based split for GPU evaluation.  Currently, it's based on the existing evaluation function.
* Extract an evaluator from GPU Hist to store the needed states.
* Added some CUDA stream/event utilities.
* Update document with references.
* Fixed a bug in approx evaluator where the number of data points is less than the number of categories.
This commit is contained in:
Jiaming Yuan
2022-02-15 03:03:12 +08:00
committed by GitHub
parent 2369d55e9a
commit 0d0abe1845
26 changed files with 1088 additions and 528 deletions

View File

@@ -159,6 +159,10 @@ class DeviceHistogram {
// Manage memory for a single GPU
template <typename GradientSumT>
struct GPUHistMakerDevice {
private:
GPUHistEvaluator<GradientSumT> evaluator_;
public:
int device_id;
EllpackPageImpl const* page;
common::Span<FeatureType const> feature_types;
@@ -182,7 +186,6 @@ struct GPUHistMakerDevice {
dh::PinnedMemory pinned;
common::Monitor monitor;
TreeEvaluator tree_evaluator;
common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints;
@@ -192,24 +195,20 @@ struct GPUHistMakerDevice {
// Storing split categories for last node.
dh::caching_device_vector<uint32_t> node_categories;
GPUHistMakerDevice(int _device_id,
EllpackPageImpl const* _page,
common::Span<FeatureType const> _feature_types,
bst_uint _n_rows,
TrainParam _param,
uint32_t column_sampler_seed,
uint32_t n_features,
GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page,
common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
BatchParam _batch_param)
: device_id(_device_id),
: evaluator_{_param, n_features, _device_id},
device_id(_device_id),
page(_page),
feature_types{_feature_types},
param(std::move(_param)),
tree_evaluator(param, n_features, _device_id),
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
batch_param(std::move(_batch_param)) {
sampler.reset(new GradientBasedSampler(
page, _n_rows, batch_param, param.subsample, param.sampling_method));
sampler.reset(new GradientBasedSampler(page, _n_rows, batch_param, param.subsample,
param.sampling_method));
if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints;
@@ -219,9 +218,8 @@ struct GPUHistMakerDevice {
// Init histogram
hist.Init(device_id, page->Cuts().TotalBins());
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
dh::MaxSharedMemoryOptin(device_id),
sizeof(GradientSumT)));
feature_groups.reset(new FeatureGroups(
page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), sizeof(GradientSumT)));
}
~GPUHistMakerDevice() { // NOLINT
@@ -231,13 +229,17 @@ struct GPUHistMakerDevice {
// Reset values for each update iteration
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns,
ObjInfo task) {
auto const& info = dmat->Info();
this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
tree_evaluator = TreeEvaluator(param, dmat->Info().num_col_, device_id);
this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param,
device_id);
this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
@@ -258,10 +260,8 @@ struct GPUHistMakerDevice {
hist.Reset();
}
DeviceSplitCandidate EvaluateRootSplit(GradientPairPrecise root_sum) {
GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight, ObjInfo task) {
int nidx = RegTree::kRoot;
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
GPUTrainingParam gpu_param(param);
auto sampled_features = column_sampler.GetFeatureSet(0);
sampled_features->SetDevice(device_id);
@@ -277,32 +277,23 @@ struct GPUHistMakerDevice {
matrix.gidx_fvalue_map,
matrix.min_fvalue,
hist.GetNodeHistogram(nidx)};
auto gain_calc = tree_evaluator.GetEvaluator<GPUTrainingParam>();
EvaluateSingleSplit(dh::ToSpan(splits_out), gain_calc, inputs);
std::vector<DeviceSplitCandidate> result(1);
dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
sizeof(DeviceSplitCandidate) * splits_out.size(),
cudaMemcpyDeviceToHost));
return result.front();
auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight, task);
return split;
}
void EvaluateLeftRightSplits(
GPUExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
common::Span<GPUExpandEntry> pinned_candidates_out) {
void EvaluateLeftRightSplits(GPUExpandEntry candidate, ObjInfo task, int left_nidx,
int right_nidx, const RegTree& tree,
common::Span<GPUExpandEntry> pinned_candidates_out) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
GPUTrainingParam gpu_param(param);
auto left_sampled_features =
column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
left_sampled_features->SetDevice(device_id);
common::Span<bst_feature_t> left_feature_set =
interaction_constraints.Query(left_sampled_features->DeviceSpan(),
left_nidx);
auto right_sampled_features =
column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx);
auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
right_sampled_features->SetDevice(device_id);
common::Span<bst_feature_t> right_feature_set =
interaction_constraints.Query(right_sampled_features->DeviceSpan(),
left_nidx);
interaction_constraints.Query(right_sampled_features->DeviceSpan(), left_nidx);
auto matrix = page->GetDeviceAccessor(device_id);
EvaluateSplitInputs<GradientSumT> left{left_nidx,
@@ -323,29 +314,11 @@ struct GPUHistMakerDevice {
matrix.gidx_fvalue_map,
matrix.min_fvalue,
hist.GetNodeHistogram(right_nidx)};
auto d_splits_out = dh::ToSpan(splits_out);
EvaluateSplits(d_splits_out, tree_evaluator.GetEvaluator<GPUTrainingParam>(), left, right);
dh::TemporaryArray<GPUExpandEntry> entries(2);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
auto d_entries = entries.data().get();
dh::LaunchN(2, [=] __device__(size_t idx) {
auto split = d_splits_out[idx];
auto nidx = idx == 0 ? left_nidx : right_nidx;
float base_weight = evaluator.CalcWeight(
nidx, gpu_param, GradStats{split.left_sum + split.right_sum});
float left_weight =
evaluator.CalcWeight(nidx, gpu_param, GradStats{split.left_sum});
float right_weight = evaluator.CalcWeight(
nidx, gpu_param, GradStats{split.right_sum});
d_entries[idx] =
GPUExpandEntry{nidx, candidate.depth + 1, d_splits_out[idx],
base_weight, left_weight, right_weight};
});
dh::safe_cuda(cudaMemcpyAsync(
pinned_candidates_out.data(), entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
this->evaluator_.EvaluateSplits(candidate, task, left, right, dh::ToSpan(entries));
dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
}
void BuildHist(int nidx) {
@@ -369,12 +342,10 @@ struct GPUHistMakerDevice {
});
}
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram,
int nidx_subtraction) {
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
// Make sure histograms are already allocated
hist.AllocateHistogram(nidx_subtraction);
return hist.HistogramExists(nidx_histogram) &&
hist.HistogramExists(nidx_parent);
return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
}
void UpdatePosition(int nidx, RegTree* p_tree) {
@@ -503,13 +474,12 @@ struct GPUHistMakerDevice {
cudaMemcpyHostToDevice));
auto d_position = row_partitioner->GetPosition();
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
auto tree_evaluator = evaluator_.GetEvaluator();
dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(
int local_idx) mutable {
dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(int local_idx) mutable {
int pos = d_position[local_idx];
bst_float weight = evaluator.CalcWeight(
pos, param_d, GradStats{d_node_sum_gradients[pos]});
bst_float weight =
tree_evaluator.CalcWeight(pos, param_d, GradStats{d_node_sum_gradients[pos]});
static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
});
@@ -562,7 +532,6 @@ struct GPUHistMakerDevice {
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;
auto evaluator = tree_evaluator.GetEvaluator();
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
auto base_weight = candidate.base_weight;
auto left_weight = candidate.left_weight * param.learning_rate;
@@ -572,48 +541,50 @@ struct GPUHistMakerDevice {
if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
if (common::InvalidCat(candidate.split.fvalue)) {
common::InvalidCategory();
std::vector<uint32_t> split_cats;
if (candidate.split.split_cats.Bits().empty()) {
if (common::InvalidCat(candidate.split.fvalue)) {
common::InvalidCategory();
}
auto cat = common::AsCat(candidate.split.fvalue);
split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0);
common::CatBitField cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);
} else {
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
auto max_cat = candidate.split.MaxCat();
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
CHECK_LE(split_cats.size(), h_cats.size());
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
node_categories.resize(candidate.split.split_cats.Bits().size());
dh::safe_cuda(cudaMemcpyAsync(
node_categories.data().get(), candidate.split.split_cats.Data(),
candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice));
}
auto cat = common::AsCat(candidate.split.fvalue);
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);
tree.ExpandCategorical(
candidate.nid, candidate.split.findex, split_cats,
candidate.split.dir == kLeftDir, base_weight, left_weight,
right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else {
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
}
evaluator_.ApplyTreeSplit(candidate, p_tree);
// Set up child constraints
auto left_child = tree[candidate.nid].LeftChild();
auto right_child = tree[candidate.nid].RightChild();
node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
tree_evaluator.AddSplit(candidate.nid, left_child, right_child,
tree[candidate.nid].SplitIndex(), candidate.left_weight,
candidate.right_weight);
node_sum_gradients[tree[candidate.nid].LeftChild()] =
candidate.split.left_sum;
node_sum_gradients[tree[candidate.nid].RightChild()] =
candidate.split.right_sum;
interaction_constraints.Split(
candidate.nid, tree[candidate.nid].SplitIndex(),
tree[candidate.nid].LeftChild(),
interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(),
tree[candidate.nid].LeftChild(),
tree[candidate.nid].RightChild());
}
GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
GPUExpandEntry InitRoot(RegTree* p_tree, ObjInfo task, dh::AllReducer* reducer) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
@@ -634,39 +605,21 @@ struct GPUHistMakerDevice {
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
// Generate first split
auto split = this->EvaluateRootSplit(root_sum);
dh::TemporaryArray<GPUExpandEntry> entries(1);
auto d_entries = entries.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
GPUTrainingParam gpu_param(param);
auto depth = p_tree->GetDepth(kRootNIdx);
dh::LaunchN(1, [=] __device__(size_t idx) {
float left_weight = evaluator.CalcWeight(kRootNIdx, gpu_param,
GradStats{split.left_sum});
float right_weight = evaluator.CalcWeight(
kRootNIdx, gpu_param, GradStats{split.right_sum});
d_entries[0] =
GPUExpandEntry(kRootNIdx, depth, split,
weight, left_weight, right_weight);
});
GPUExpandEntry root_entry;
dh::safe_cuda(cudaMemcpyAsync(
&root_entry, entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
auto root_entry = this->EvaluateRootSplit(root_sum, weight, task);
return root_entry;
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree;
Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task);
monitor.Stop("Reset");
monitor.Start("InitRoot");
driver.Push({ this->InitRoot(p_tree, reducer) });
driver.Push({ this->InitRoot(p_tree, task, reducer) });
monitor.Stop("InitRoot");
auto num_leaves = 1;
@@ -703,8 +656,7 @@ struct GPUHistMakerDevice {
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx, *p_tree,
this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
} else {
@@ -819,14 +771,13 @@ class GPUHistMakerSpecialised {
CHECK(*local_tree == reference_tree);
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree) {
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree) {
monitor_.Start("InitData");
this->InitData(p_fmat);
monitor_.Stop("InitData");
gpair->SetDevice(device_);
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_);
}
bool UpdatePredictionCache(const DMatrix *data,