CPU evaluation for cat data. (#7393)

* Implementation for one hot based.
* Implementation for partition based. (LightGBM)
This commit is contained in:
Jiaming Yuan
2021-11-06 14:41:35 +08:00
committed by GitHub
parent 6ede12412c
commit d7d1b6e3a6
15 changed files with 540 additions and 166 deletions

View File

@@ -172,12 +172,10 @@ SimpleLCG::StateType SimpleLCG::operator()() {
state_ = (alpha_ * state_) % mod_;
return state_;
}
SimpleLCG::StateType SimpleLCG::Min() const {
return seed_ * alpha_;
}
SimpleLCG::StateType SimpleLCG::Max() const {
return max_value_;
}
SimpleLCG::StateType SimpleLCG::Min() const { return min(); }
SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
// Make sure it's compile time constant.
static_assert(SimpleLCG::max() - SimpleLCG::min(), "");
void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
@@ -291,6 +289,7 @@ void RandomDataGenerator::GenerateCSR(
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
float sparsity = sparsity_ * (upper_ - lower_) + lower_;
SimpleRealUniformDistribution<bst_float> cat(0.0, max_cat_);
h_rptr.emplace_back(0);
for (size_t i = 0; i < rows_; ++i) {
@@ -298,7 +297,11 @@ void RandomDataGenerator::GenerateCSR(
for (size_t j = 0; j < cols_; ++j) {
auto g = dist(&lcg);
if (g >= sparsity) {
g = dist(&lcg);
if (common::IsCat(ft_, j)) {
g = common::AsCat(cat(&lcg));
} else {
g = dist(&lcg);
}
h_value.emplace_back(g);
rptr++;
h_cols.emplace_back(j);
@@ -347,11 +350,15 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
}
if (device_ >= 0) {
out->Info().labels_.SetDevice(device_);
out->Info().feature_types.SetDevice(device_);
for (auto const& page : out->GetBatches<SparsePage>()) {
page.data.SetDevice(device_);
page.offset.SetDevice(device_);
}
}
if (!ft_.empty()) {
out->Info().feature_types.HostVector() = ft_;
}
return out;
}