parent
b1b6246e35
commit
c2508814ff
@ -74,10 +74,15 @@ inline void InvalidCategory() {
|
|||||||
// values to be less than this last representable value.
|
// values to be less than this last representable value.
|
||||||
auto str = std::to_string(OutOfRangeCat());
|
auto str = std::to_string(OutOfRangeCat());
|
||||||
LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
|
LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, "
|
||||||
"less than total umber of categories in training data and less than " +
|
"less than total number of categories in training data and less than " +
|
||||||
str;
|
str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief Check that the largest categorical value can account for every observed category.
 *
 * The check fails (through CHECK_GE) when `max_cat + 1` is smaller than `n_categories`.
 *
 * \param max_cat      Largest categorical value found for a feature.
 * \param n_categories Number of categories found for that feature.
 */
inline void CheckMaxCat(float max_cat, size_t n_categories) {
  CHECK_GE(max_cat + 1, n_categories)
      << "Maximum category should not be lesser than the total number of categories.";
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Whether we should use one-hot encoding for categorical data.
|
* \brief Whether we should use one-hot encoding for categorical data.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -164,6 +164,74 @@ class Range {
|
|||||||
Iterator end_;
|
Iterator end_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Transform iterator that takes an index and calls transform operator.
|
||||||
|
*
|
||||||
|
* This is CPU-only right now as taking host device function as operator complicates the
|
||||||
|
* code. For device side one can use `thrust::transform_iterator` instead.
|
||||||
|
*/
|
||||||
|
template <typename Fn>
|
||||||
|
class IndexTransformIter {
|
||||||
|
size_t iter_{0};
|
||||||
|
Fn fn_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using iterator_category = std::random_access_iterator_tag; // NOLINT
|
||||||
|
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
|
||||||
|
using difference_type = detail::ptrdiff_t; // NOLINT
|
||||||
|
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
|
||||||
|
using pointer = std::add_pointer_t<value_type>; // NOLINT
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* \param op Transform operator, takes a size_t index as input.
|
||||||
|
*/
|
||||||
|
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
|
||||||
|
IndexTransformIter(IndexTransformIter const &) = default;
|
||||||
|
IndexTransformIter& operator=(IndexTransformIter&&) = default;
|
||||||
|
IndexTransformIter& operator=(IndexTransformIter const& that) {
|
||||||
|
iter_ = that.iter_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
value_type operator*() const { return fn_(iter_); }
|
||||||
|
|
||||||
|
auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
|
||||||
|
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
|
||||||
|
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
|
||||||
|
|
||||||
|
IndexTransformIter &operator++() {
|
||||||
|
iter_++;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
IndexTransformIter operator++(int) {
|
||||||
|
auto ret = *this;
|
||||||
|
++(*this);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
IndexTransformIter &operator+=(difference_type n) {
|
||||||
|
iter_ += n;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
IndexTransformIter &operator-=(difference_type n) {
|
||||||
|
(*this) += -n;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
IndexTransformIter operator+(difference_type n) const {
|
||||||
|
auto ret = *this;
|
||||||
|
return ret += n;
|
||||||
|
}
|
||||||
|
IndexTransformIter operator-(difference_type n) const {
|
||||||
|
auto ret = *this;
|
||||||
|
return ret -= n;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
auto MakeIndexTransformIter(Fn&& fn) {
|
||||||
|
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
|
||||||
|
}
|
||||||
|
|
||||||
int AllVisibleGPUs();
|
int AllVisibleGPUs();
|
||||||
|
|
||||||
inline void AssertGPUSupport() {
|
inline void AssertGPUSupport() {
|
||||||
|
|||||||
@ -468,11 +468,17 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
|
auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
|
||||||
auto &cut_values = cuts->cut_values_.HostVector();
|
if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) {
|
||||||
for (auto const &v : categories) {
|
InvalidCategory();
|
||||||
cut_values.push_back(AsCat(v));
|
|
||||||
}
|
}
|
||||||
|
auto &cut_values = cuts->cut_values_.HostVector();
|
||||||
|
auto max_cat = *std::max_element(categories.cbegin(), categories.cend());
|
||||||
|
CheckMaxCat(max_cat, categories.size());
|
||||||
|
for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) {
|
||||||
|
cut_values.push_back(i);
|
||||||
|
}
|
||||||
|
return max_cat;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename WQSketch>
|
template <typename WQSketch>
|
||||||
@ -505,11 +511,12 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
float max_cat{-1.f};
|
||||||
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
||||||
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
||||||
typename WQSketch::SummaryContainer const& a = final_summaries[fid];
|
typename WQSketch::SummaryContainer const& a = final_summaries[fid];
|
||||||
if (IsCat(feature_types_, fid)) {
|
if (IsCat(feature_types_, fid)) {
|
||||||
AddCategories(categories_.at(fid), cuts);
|
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), cuts));
|
||||||
} else {
|
} else {
|
||||||
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
||||||
// push a value that is greater than anything
|
// push a value that is greater than anything
|
||||||
@ -527,30 +534,7 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
|
|||||||
cuts->cut_ptrs_.HostVector().push_back(cut_size);
|
cuts->cut_ptrs_.HostVector().push_back(cut_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_categorical_) {
|
cuts->SetCategorical(this->has_categorical_, max_cat);
|
||||||
for (auto const &feat : categories_) {
|
|
||||||
if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) {
|
|
||||||
InvalidCategory();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto const &ptrs = cuts->Ptrs();
|
|
||||||
auto const &vals = cuts->Values();
|
|
||||||
|
|
||||||
float max_cat{-std::numeric_limits<float>::infinity()};
|
|
||||||
for (size_t i = 1; i < ptrs.size(); ++i) {
|
|
||||||
if (IsCat(feature_types_, i - 1)) {
|
|
||||||
auto beg = ptrs[i - 1];
|
|
||||||
auto end = ptrs[i];
|
|
||||||
auto feat = Span<float const>{vals}.subspan(beg, end - beg);
|
|
||||||
auto max_elem = *std::max_element(feat.cbegin(), feat.cend());
|
|
||||||
if (max_elem > max_cat) {
|
|
||||||
max_cat = max_elem;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cuts->SetCategorical(true, max_cat);
|
|
||||||
}
|
|
||||||
|
|
||||||
monitor_.Stop(__func__);
|
monitor_.Stop(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2020 by XGBoost Contributors
|
* Copyright 2020-2022 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <thrust/binary_search.h>
|
#include <thrust/binary_search.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
@ -583,13 +583,13 @@ void SketchContainer::AllReduce() {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/**
 * \brief Predicate over sketch entries that flags invalid categorical values.
 *
 * Entry `i` is mapped to its feature with dh::SegmentId(ptrs, i).  The result is true only
 * when that feature is categorical according to `ft` and InvalidCat rejects the entry's
 * value.
 */
struct InvalidCatOp {
  Span<SketchEntry const> values;
  Span<size_t const> ptrs;
  Span<FeatureType const> ft;

  XGBOOST_DEVICE bool operator()(size_t i) const {
    auto const feature = dh::SegmentId(ptrs, i);
    if (!IsCat(ft, feature)) {
      return false;
    }
    return InvalidCat(values[i].value);
  }
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
@ -611,7 +611,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
|||||||
|
|
||||||
p_cuts->min_vals_.SetDevice(device_);
|
p_cuts->min_vals_.SetDevice(device_);
|
||||||
auto d_min_values = p_cuts->min_vals_.DeviceSpan();
|
auto d_min_values = p_cuts->min_vals_.DeviceSpan();
|
||||||
auto in_cut_values = dh::ToSpan(this->Current());
|
auto const in_cut_values = dh::ToSpan(this->Current());
|
||||||
|
|
||||||
// Set up output ptr
|
// Set up output ptr
|
||||||
p_cuts->cut_ptrs_.SetDevice(device_);
|
p_cuts->cut_ptrs_.SetDevice(device_);
|
||||||
@ -619,26 +619,70 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
|||||||
h_out_columns_ptr.clear();
|
h_out_columns_ptr.clear();
|
||||||
h_out_columns_ptr.push_back(0);
|
h_out_columns_ptr.push_back(0);
|
||||||
auto const& h_feature_types = this->feature_types_.ConstHostSpan();
|
auto const& h_feature_types = this->feature_types_.ConstHostSpan();
|
||||||
for (bst_feature_t i = 0; i < num_columns_; ++i) {
|
|
||||||
size_t column_size = std::max(static_cast<size_t>(1ul),
|
auto d_ft = feature_types_.ConstDeviceSpan();
|
||||||
this->Column(i).size());
|
|
||||||
if (IsCat(h_feature_types, i)) {
|
std::vector<SketchEntry> max_values;
|
||||||
h_out_columns_ptr.push_back(static_cast<size_t>(column_size));
|
float max_cat{-1.f};
|
||||||
} else {
|
if (has_categorical_) {
|
||||||
h_out_columns_ptr.push_back(std::min(static_cast<size_t>(column_size),
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
static_cast<size_t>(num_bins_)));
|
auto key_it = dh::MakeTransformIterator<bst_feature_t>(
|
||||||
|
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t {
|
||||||
|
return dh::SegmentId(d_in_columns_ptr, i);
|
||||||
|
});
|
||||||
|
auto invalid_op = InvalidCatOp{in_cut_values, d_in_columns_ptr, d_ft};
|
||||||
|
auto val_it = dh::MakeTransformIterator<SketchEntry>(
|
||||||
|
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
|
||||||
|
auto fidx = dh::SegmentId(d_in_columns_ptr, i);
|
||||||
|
auto v = in_cut_values[i];
|
||||||
|
if (IsCat(d_ft, fidx)) {
|
||||||
|
if (invalid_op(i)) {
|
||||||
|
// use inf to indicate invalid value, this way we can keep it as an
|
||||||
|
// indicator in the reduce operation as it's always the greatest value.
|
||||||
|
v.value = std::numeric_limits<float>::infinity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
});
|
||||||
|
CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1);
|
||||||
|
max_values.resize(d_in_columns_ptr.size() - 1);
|
||||||
|
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
|
||||||
|
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it,
|
||||||
|
thrust::make_discard_iterator(), d_max_values.begin(),
|
||||||
|
thrust::equal_to<bst_feature_t>{},
|
||||||
|
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
|
||||||
|
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values));
|
||||||
|
auto max_it = common::MakeIndexTransformIter([&](auto i) {
|
||||||
|
if (IsCat(h_feature_types, i)) {
|
||||||
|
return max_values[i].value;
|
||||||
|
}
|
||||||
|
return -1.f;
|
||||||
|
});
|
||||||
|
max_cat = *std::max_element(max_it, max_it + max_values.size());
|
||||||
|
if (std::isinf(max_cat)) {
|
||||||
|
InvalidCategory();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(),
|
|
||||||
h_out_columns_ptr.begin());
|
|
||||||
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
|
|
||||||
|
|
||||||
// Set up output cuts
|
// Set up output cuts
|
||||||
|
for (bst_feature_t i = 0; i < num_columns_; ++i) {
|
||||||
|
size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
|
||||||
|
if (IsCat(h_feature_types, i)) {
|
||||||
|
// column_size is the number of unique values in that feature.
|
||||||
|
CheckMaxCat(max_values[i].value, column_size);
|
||||||
|
h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0.
|
||||||
|
} else {
|
||||||
|
h_out_columns_ptr.push_back(
|
||||||
|
std::min(static_cast<size_t>(column_size), static_cast<size_t>(num_bins_)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin());
|
||||||
|
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
|
||||||
|
|
||||||
size_t total_bins = h_out_columns_ptr.back();
|
size_t total_bins = h_out_columns_ptr.back();
|
||||||
p_cuts->cut_values_.SetDevice(device_);
|
p_cuts->cut_values_.SetDevice(device_);
|
||||||
p_cuts->cut_values_.Resize(total_bins);
|
p_cuts->cut_values_.Resize(total_bins);
|
||||||
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
|
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
|
||||||
auto d_ft = feature_types_.ConstDeviceSpan();
|
|
||||||
|
|
||||||
dh::LaunchN(total_bins, [=] __device__(size_t idx) {
|
dh::LaunchN(total_bins, [=] __device__(size_t idx) {
|
||||||
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
|
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
|
||||||
@ -667,8 +711,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (IsCat(d_ft, column_id)) {
|
if (IsCat(d_ft, column_id)) {
|
||||||
assert(out_column.size() == in_column.size());
|
out_column[idx] = idx;
|
||||||
out_column[idx] = in_column[idx].value;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -684,36 +727,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
|||||||
out_column[idx] = in_column[idx+1].value;
|
out_column[idx] = in_column[idx+1].value;
|
||||||
});
|
});
|
||||||
|
|
||||||
float max_cat{-1.0f};
|
|
||||||
if (has_categorical_) {
|
|
||||||
auto invalid_op = InvalidCatOp{out_cut_values, d_out_columns_ptr, d_ft};
|
|
||||||
auto it = dh::MakeTransformIterator<thrust::pair<bool, float>>(
|
|
||||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
|
|
||||||
auto fidx = dh::SegmentId(d_out_columns_ptr, i);
|
|
||||||
if (IsCat(d_ft, fidx)) {
|
|
||||||
auto invalid = invalid_op(i);
|
|
||||||
auto v = out_cut_values[i];
|
|
||||||
return thrust::make_pair(invalid, v);
|
|
||||||
}
|
|
||||||
return thrust::make_pair(false, std::numeric_limits<float>::min());
|
|
||||||
});
|
|
||||||
|
|
||||||
bool invalid{false};
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
|
||||||
thrust::tie(invalid, max_cat) =
|
|
||||||
thrust::reduce(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
|
|
||||||
thrust::make_pair(false, std::numeric_limits<float>::min()),
|
|
||||||
[=] XGBOOST_DEVICE(thrust::pair<bool, bst_cat_t> const &l,
|
|
||||||
thrust::pair<bool, bst_cat_t> const &r) {
|
|
||||||
return thrust::make_pair(l.first || r.first, std::max(l.second, r.second));
|
|
||||||
});
|
|
||||||
if (invalid) {
|
|
||||||
InvalidCategory();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p_cuts->SetCategorical(this->has_categorical_, max_cat);
|
p_cuts->SetCategorical(this->has_categorical_, max_cat);
|
||||||
|
|
||||||
timer_.Stop(__func__);
|
timer_.Stop(__func__);
|
||||||
}
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
|
|||||||
@ -61,6 +61,9 @@ class TestGPUUpdaters:
|
|||||||
def test_categorical(self, rows, cols, rounds, cats):
|
def test_categorical(self, rows, cols, rounds, cats):
|
||||||
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
|
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
|
||||||
|
|
||||||
|
def test_max_cat(self) -> None:
    """Run the shared max-category check with the ``gpu_hist`` tree method."""
    tree_method = "gpu_hist"
    self.cputest.run_max_cat(tree_method)
|
||||||
|
|
||||||
def test_categorical_32_cat(self):
|
def test_categorical_32_cat(self):
|
||||||
'''32 hits the bound of integer bitset, so special test'''
|
'''32 hits the bound of integer bitset, so special test'''
|
||||||
rows = 1000
|
rows = 1000
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from random import choice
|
||||||
|
from string import ascii_lowercase
|
||||||
import testing as tm
|
import testing as tm
|
||||||
import pytest
|
import pytest
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
@ -167,6 +169,30 @@ class TestTreeMethod:
|
|||||||
|
|
||||||
def test_invalid_category(self) -> None:
    """Run the invalid-category check for the ``approx`` and ``hist`` tree methods."""
    for tree_method in ("approx", "hist"):
        self.run_invalid_category(tree_method)
|
||||||
|
|
||||||
|
def run_max_cat(self, tree_method: str) -> None:
    """Fit a regressor on fewer rows than there are categories.

    A categorical series is built from ``n_categories`` random three-letter strings and
    truncated to its first ``n_samples`` rows.  A regressor is then trained with the
    given ``tree_method`` and the RMSE recorded on the training data is asserted to be
    non-increasing.
    """
    import pandas as pd

    n_categories = 100
    n_samples = 5

    names = [
        "".join(choice(ascii_lowercase) for _ in range(3))
        for _ in range(n_categories)
    ]
    full_series = pd.Series(names, dtype="category")
    X = full_series[:n_samples].to_frame()

    reg = xgb.XGBRegressor(
        enable_categorical=True, tree_method=tree_method, n_estimators=10
    )
    y = pd.Series(range(n_samples))
    reg.fit(X=X, y=y, eval_set=[(X, y)])

    rmse_history = reg.evals_result()["validation_0"]["rmse"]
    assert tm.non_increasing(rmse_history)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_max_cat(self, tree_method) -> None:
    """Run the max-category check once per parametrized tree method."""
    method = tree_method
    self.run_max_cat(method)
|
||||||
|
|
||||||
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
|
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
|
||||||
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user