CPU evaluation for cat data. (#7393)
* Implementation for one hot based. * Implementation for partition based. (LightGBM)
This commit is contained in:
@@ -172,12 +172,10 @@ SimpleLCG::StateType SimpleLCG::operator()() {
|
||||
state_ = (alpha_ * state_) % mod_;
|
||||
return state_;
|
||||
}
|
||||
SimpleLCG::StateType SimpleLCG::Min() const {
|
||||
return seed_ * alpha_;
|
||||
}
|
||||
SimpleLCG::StateType SimpleLCG::Max() const {
|
||||
return max_value_;
|
||||
}
|
||||
SimpleLCG::StateType SimpleLCG::Min() const { return min(); }
|
||||
SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
|
||||
// Make sure it's compile time constant.
|
||||
static_assert(SimpleLCG::max() - SimpleLCG::min(), "");
|
||||
|
||||
void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
|
||||
@@ -291,6 +289,7 @@ void RandomDataGenerator::GenerateCSR(
|
||||
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
|
||||
float sparsity = sparsity_ * (upper_ - lower_) + lower_;
|
||||
SimpleRealUniformDistribution<bst_float> cat(0.0, max_cat_);
|
||||
|
||||
h_rptr.emplace_back(0);
|
||||
for (size_t i = 0; i < rows_; ++i) {
|
||||
@@ -298,7 +297,11 @@ void RandomDataGenerator::GenerateCSR(
|
||||
for (size_t j = 0; j < cols_; ++j) {
|
||||
auto g = dist(&lcg);
|
||||
if (g >= sparsity) {
|
||||
g = dist(&lcg);
|
||||
if (common::IsCat(ft_, j)) {
|
||||
g = common::AsCat(cat(&lcg));
|
||||
} else {
|
||||
g = dist(&lcg);
|
||||
}
|
||||
h_value.emplace_back(g);
|
||||
rptr++;
|
||||
h_cols.emplace_back(j);
|
||||
@@ -347,11 +350,15 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
|
||||
}
|
||||
if (device_ >= 0) {
|
||||
out->Info().labels_.SetDevice(device_);
|
||||
out->Info().feature_types.SetDevice(device_);
|
||||
for (auto const& page : out->GetBatches<SparsePage>()) {
|
||||
page.data.SetDevice(device_);
|
||||
page.offset.SetDevice(device_);
|
||||
}
|
||||
}
|
||||
if (!ft_.empty()) {
|
||||
out->Info().feature_types.HostVector() = ft_;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user