Optimize GPU evaluation function for categorical data. (#7705)
* Use transform and cache.
This commit is contained in:
parent
18a4af63aa
commit
1d468e20a4
@ -74,12 +74,12 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, random_state=1994, test_size=0.2
|
||||
)
|
||||
# Specify `enable_categorical`.
|
||||
# Specify `enable_categorical` to True.
|
||||
clf = xgb.XGBClassifier(
|
||||
**params,
|
||||
eval_metric="auc",
|
||||
enable_categorical=True,
|
||||
max_cat_to_onehot=1, # We use optimal partitioning exclusively
|
||||
max_cat_to_onehot=1, # We use optimal partitioning exclusively
|
||||
)
|
||||
clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)])
|
||||
clf.save_model(os.path.join(output_dir, "categorical.json"))
|
||||
@ -94,13 +94,12 @@ def onehot_encoding_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> Non
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, random_state=42, test_size=0.2
|
||||
)
|
||||
# Specify `enable_categorical`.
|
||||
clf = xgb.XGBClassifier(**params, enable_categorical=False)
|
||||
# Specify `enable_categorical` to False as we are using encoded data.
|
||||
clf = xgb.XGBClassifier(**params, eval_metric="auc", enable_categorical=False)
|
||||
clf.fit(
|
||||
X_train,
|
||||
y_train,
|
||||
eval_set=[(X_test, y_test), (X_train, y_train)],
|
||||
eval_metric="auc",
|
||||
)
|
||||
clf.save_model(os.path.join(output_dir, "one-hot.json"))
|
||||
|
||||
|
||||
@ -51,6 +51,12 @@ class GPUHistEvaluator {
|
||||
dh::CUDAStream copy_stream_;
|
||||
// storage for sorted index of feature histogram, used for sort based splits.
|
||||
dh::device_vector<bst_feature_t> cat_sorted_idx_;
|
||||
// cached input for sorting the histogram, used for sort based splits.
|
||||
using SortPair = thrust::tuple<uint32_t, double>;
|
||||
dh::device_vector<SortPair> sort_input_;
|
||||
// cache for feature index
|
||||
dh::device_vector<bst_feature_t> feature_idx_;
|
||||
// Training param used for evaluation
|
||||
TrainParam param_;
|
||||
// whether the input data requires sort based split, which is more complicated so we try
|
||||
// to avoid it if possible.
|
||||
@ -95,6 +101,13 @@ class GPUHistEvaluator {
|
||||
return dh::ToSpan(cat_sorted_idx_);
|
||||
}
|
||||
|
||||
auto SortInput(EvaluateSplitInputs<GradientSumT> left) {
|
||||
if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
|
||||
return dh::ToSpan(sort_input_).first(left.feature_values.size());
|
||||
}
|
||||
return dh::ToSpan(sort_input_);
|
||||
}
|
||||
|
||||
public:
|
||||
GPUHistEvaluator(TrainParam const ¶m, bst_feature_t n_features, int32_t device)
|
||||
: tree_evaluator_{param, n_features, device}, param_{param} {}
|
||||
|
||||
@ -54,6 +54,21 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
|
||||
cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
|
||||
|
||||
cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time.
|
||||
sort_input_.resize(cat_sorted_idx_.size());
|
||||
|
||||
/**
|
||||
* cache feature index binary search result
|
||||
*/
|
||||
feature_idx_.resize(cat_sorted_idx_.size());
|
||||
auto d_fidxes = dh::ToSpan(feature_idx_);
|
||||
auto it = thrust::make_counting_iterator(0ul);
|
||||
auto values = cuts.cut_values_.ConstDeviceSpan();
|
||||
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
|
||||
thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(),
|
||||
feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) {
|
||||
auto fidx = dh::SegmentId(ptrs, i);
|
||||
return fidx;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -62,35 +77,55 @@ template <typename GradientSumT>
|
||||
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
|
||||
EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
|
||||
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto sorted_idx = this->SortedIdx(left);
|
||||
dh::Iota(sorted_idx);
|
||||
// sort 2 nodes and all the features at the same time, disregarding colmun sampling.
|
||||
thrust::stable_sort(
|
||||
thrust::cuda::par(alloc), dh::tbegin(sorted_idx), dh::tend(sorted_idx),
|
||||
[evaluator, left, right] XGBOOST_DEVICE(size_t l, size_t r) {
|
||||
auto l_is_left = l < left.feature_values.size();
|
||||
auto r_is_left = r < left.feature_values.size();
|
||||
if (l_is_left != r_is_left) {
|
||||
return l_is_left; // not the same node
|
||||
}
|
||||
auto data = this->SortInput(left);
|
||||
auto it = thrust::make_counting_iterator(0u);
|
||||
auto d_feature_idx = dh::ToSpan(feature_idx_);
|
||||
thrust::transform(thrust::cuda::par(alloc), it, it + data.size(), dh::tbegin(data),
|
||||
[=] XGBOOST_DEVICE(uint32_t i) {
|
||||
auto is_left = i < left.feature_values.size();
|
||||
auto const &input = is_left ? left : right;
|
||||
auto j = i - (is_left ? 0 : input.feature_values.size());
|
||||
auto fidx = d_feature_idx[j];
|
||||
if (common::IsCat(input.feature_types, fidx)) {
|
||||
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[j]);
|
||||
return thrust::make_tuple(i, lw);
|
||||
}
|
||||
return thrust::make_tuple(i, 0.0);
|
||||
});
|
||||
thrust::stable_sort_by_key(thrust::cuda::par(alloc), dh::tbegin(data), dh::tend(data),
|
||||
dh::tbegin(sorted_idx),
|
||||
[=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
|
||||
auto li = thrust::get<0>(l);
|
||||
auto ri = thrust::get<0>(r);
|
||||
|
||||
auto const &input = l_is_left ? left : right;
|
||||
l -= (l_is_left ? 0 : input.feature_values.size());
|
||||
r -= (r_is_left ? 0 : input.feature_values.size());
|
||||
auto l_is_left = li < left.feature_values.size();
|
||||
auto r_is_left = ri < left.feature_values.size();
|
||||
|
||||
auto lfidx = dh::SegmentId(input.feature_segments, l);
|
||||
auto rfidx = dh::SegmentId(input.feature_segments, r);
|
||||
if (lfidx != rfidx) {
|
||||
return lfidx < rfidx; // not the same feature
|
||||
}
|
||||
if (common::IsCat(input.feature_types, lfidx)) {
|
||||
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[l]);
|
||||
auto rw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[r]);
|
||||
return lw < rw;
|
||||
}
|
||||
return l < r;
|
||||
});
|
||||
if (l_is_left != r_is_left) {
|
||||
return l_is_left; // not the same node
|
||||
}
|
||||
|
||||
auto const &input = l_is_left ? left : right;
|
||||
li -= (l_is_left ? 0 : input.feature_values.size());
|
||||
ri -= (r_is_left ? 0 : input.feature_values.size());
|
||||
|
||||
auto lfidx = d_feature_idx[li];
|
||||
auto rfidx = d_feature_idx[ri];
|
||||
|
||||
if (lfidx != rfidx) {
|
||||
return lfidx < rfidx; // not the same feature
|
||||
}
|
||||
|
||||
if (common::IsCat(input.feature_types, lfidx)) {
|
||||
auto lw = thrust::get<1>(l);
|
||||
auto rw = thrust::get<1>(r);
|
||||
return lw < rw;
|
||||
}
|
||||
return li < ri;
|
||||
});
|
||||
return dh::ToSpan(cat_sorted_idx_);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user