Remove column major specialization. (#5755)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-06-05 16:19:14 +08:00
committed by GitHub
parent bd9d57f579
commit cacff9232a
10 changed files with 70 additions and 204 deletions

View File

@@ -271,12 +271,12 @@ class CPUPredictor : public Predictor {
PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const {
auto threads = omp_get_max_threads();
auto m = dmlc::get<Adapter>(x);
CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature)
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model.";
MetaInfo info;
info.num_col_ = m.NumColumns();
info.num_row_ = m.NumRows();
info.num_col_ = m->NumColumns();
info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model);
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
auto &predictions = out_preds->predictions.HostVector();
@@ -284,17 +284,17 @@ class CPUPredictor : public Predictor {
InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
size_t constexpr kUnroll = 8;
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
&m, missing, common::Span<Entry>{workspace}),
m.get(), missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp);
}
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
float missing, PredictionCacheEntry *out_preds,
uint32_t tree_begin, unsigned tree_end) const override {
if (x.type() == typeid(data::DenseAdapter)) {
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
this->DispatchedInplacePredict<data::DenseAdapter>(
x, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(data::CSRAdapter)) {
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
this->DispatchedInplacePredict<data::CSRAdapter>(
x, model, missing, out_preds, tree_begin, tree_end);
} else {