Remove column major specialization. (#5755)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -271,12 +271,12 @@ class CPUPredictor : public Predictor {
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
auto threads = omp_get_max_threads();
|
||||
auto m = dmlc::get<Adapter>(x);
|
||||
CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature)
|
||||
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
MetaInfo info;
|
||||
info.num_col_ = m.NumColumns();
|
||||
info.num_row_ = m.NumRows();
|
||||
info.num_col_ = m->NumColumns();
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
@@ -284,17 +284,17 @@ class CPUPredictor : public Predictor {
|
||||
InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
|
||||
size_t constexpr kUnroll = 8;
|
||||
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
|
||||
&m, missing, common::Span<Entry>{workspace}),
|
||||
m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
if (x.type() == typeid(data::DenseAdapter)) {
|
||||
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(data::CSRAdapter)) {
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user