Fix calling GPU predictor (#4836)

* Fix calling GPU predictor
This commit is contained in:
Jiaming Yuan
2019-09-05 19:09:38 -04:00
committed by GitHub
parent 52d44e07fe
commit a5f232feb8
5 changed files with 85 additions and 5 deletions

View File

@@ -49,6 +49,7 @@ class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
};
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// Since CSR is the default data structure, `source_` is always available.
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
auto begin_iter = BatchIterator<SparsePage>(
new SimpleBatchIteratorImpl<SparsePage>(&(cast->page_)));

View File

@@ -191,7 +191,7 @@ class GBTree : public GradientBooster {
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) override {
CHECK(configured_);
GetPredictor()->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
}
void PredictInstance(const SparsePage::Inst& inst,
@@ -242,8 +242,22 @@ class GBTree : public GradientBooster {
int bst_group,
std::vector<std::unique_ptr<RegTree> >* ret);
std::unique_ptr<Predictor> const& GetPredictor() const {
std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
DMatrix* f_dmat = nullptr) const {
CHECK(configured_);
// GPU_Hist by default has its prediction cache calculated from quantile values, so the
// GPU predictor is not used for the training dataset. But when XGBoost performs continued
// training with an existing model, the prediction cache is not available and the number
// of trees is non-zero, so the whole training dataset would get copied onto the GPU for
// precise prediction. This condition tries to avoid such a copy by calling the CPU
// predictor instead.
if ((out_pred && out_pred->Size() == 0) &&
(model_.param.num_trees != 0) &&
// FIXME(trivialfis): Implement a better method for testing whether data is on
// device after DMatrix refactoring is done.
(f_dmat && !((*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead()))) {
return cpu_predictor_;
}
if (tparam_.predictor == "cpu_predictor") {
CHECK(cpu_predictor_);
return cpu_predictor_;

View File

@@ -134,7 +134,7 @@ class CPUPredictor : public Predictor {
} else {
if (!base_margin.empty()) {
std::ostringstream oss;
oss << "Warning: Ignoring the base margin, since it has incorrect length. "
oss << "Ignoring the base margin, since it has incorrect length. "
<< "The base margin must be an array of length ";
if (model.param.num_output_group > 1) {
oss << "[num_class] * [number of data points], i.e. "
@@ -145,7 +145,7 @@ class CPUPredictor : public Predictor {
}
oss << "Instead, all data points will use "
<< "base_score = " << model.base_margin;
LOG(INFO) << oss.str();
LOG(WARNING) << oss.str();
}
std::fill(out_preds_h.begin(), out_preds_h.end(), model.base_margin);
}