Thread safe, inplace prediction. (#5389)
Normal prediction with DMatrix is now thread safe with locks. Added inplace prediction is lock free thread safe. When data is on device (cupy, cudf), the returned data is also on device. * Implementation for numpy, csr, cudf and cupy. * Implementation for dask. * Remove sync in simple dmatrix.
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
#include <dmlc/parameter.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
@@ -18,6 +20,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dmlc/any.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/model.h"
|
||||
@@ -205,7 +208,7 @@ class LearnerConfiguration : public Learner {
|
||||
PredictionContainer cache_;
|
||||
|
||||
protected:
|
||||
bool need_configuration_;
|
||||
std::atomic<bool> need_configuration_;
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// Stores information like best-iteration for early stopping.
|
||||
std::map<std::string, std::string> attributes_;
|
||||
@@ -214,6 +217,7 @@ class LearnerConfiguration : public Learner {
|
||||
LearnerModelParam learner_model_param_;
|
||||
LearnerTrainParam tparam_;
|
||||
std::vector<std::string> metric_names_;
|
||||
std::mutex config_lock_;
|
||||
|
||||
public:
|
||||
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||
@@ -226,6 +230,9 @@ class LearnerConfiguration : public Learner {
|
||||
// Configuration before data is known.
|
||||
|
||||
void Configure() override {
|
||||
// Varient of double checked lock
|
||||
if (!this->need_configuration_) { return; }
|
||||
std::lock_guard<std::mutex> gard(config_lock_);
|
||||
if (!this->need_configuration_) { return; }
|
||||
|
||||
monitor_.Start("Configure");
|
||||
@@ -1003,6 +1010,23 @@ class LearnerImpl : public LearnerIO {
|
||||
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
|
||||
return (*XGBAPIThreadLocalStore::Get())[this];
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, std::string const &type,
|
||||
float missing, HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin = 0, uint32_t layer_end = 0) override {
|
||||
this->Configure();
|
||||
auto& out_predictions = this->GetThreadLocal().prediction_entry;
|
||||
this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
|
||||
layer_end);
|
||||
if (type == "value") {
|
||||
obj_->PredTransform(&out_predictions.predictions);
|
||||
} else if (type == "margin") {
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported prediction type:" << type;
|
||||
}
|
||||
*out_preds = &out_predictions.predictions;
|
||||
}
|
||||
|
||||
const std::map<std::string, std::string>& GetConfigurationArguments() const override {
|
||||
return cfg_;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user