De-duplicate GPU parameters. (#4454)
* Only define `gpu_id` and `n_gpus` in `LearnerTrainParam` * Pass LearnerTrainParam through XGBoost vid factory method. * Disable all GPU usage when GPU related parameters are not specified (fixes XGBoost choosing GPU over aggressively). * Test learner train param io. * Fix gpu pickling.
This commit is contained in:
@@ -20,20 +20,6 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
|
||||
/*! \brief prediction parameters */
|
||||
struct GPUPredictionParam : public dmlc::Parameter<GPUPredictionParam> {
|
||||
int gpu_id;
|
||||
int n_gpus;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPUPredictionParam) {
|
||||
DMLC_DECLARE_FIELD(gpu_id).set_lower_bound(0).set_default(0).describe(
|
||||
"Device ordinal for GPU prediction.");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_lower_bound(-1).set_default(1).describe(
|
||||
"Number of devices to use for prediction.");
|
||||
}
|
||||
};
|
||||
DMLC_REGISTER_PARAMETER(GPUPredictionParam);
|
||||
|
||||
template <typename IterT>
|
||||
void IncrementOffset(IterT begin_itr, IterT end_itr, size_t amount) {
|
||||
thrust::transform(begin_itr, end_itr, begin_itr,
|
||||
@@ -387,14 +373,15 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
GPUPredictor() // NOLINT
|
||||
: cpu_predictor_(Predictor::Create("cpu_predictor")) {} // NOLINT
|
||||
GPUPredictor()
|
||||
: cpu_predictor_(Predictor::Create("cpu_predictor", learner_param_)) {}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
GPUSet devices = GPUSet::All(
|
||||
param_.gpu_id, param_.n_gpus, dmat->Info().num_row_);
|
||||
GPUSet devices = GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus,
|
||||
dmat->Info().num_row_);
|
||||
CHECK_NE(devices.Size(), 0);
|
||||
ConfigureShards(devices);
|
||||
|
||||
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
|
||||
@@ -508,9 +495,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
const std::vector<std::shared_ptr<DMatrix>>& cache) override {
|
||||
Predictor::Init(cfg, cache);
|
||||
cpu_predictor_->Init(cfg, cache);
|
||||
param_.InitAllowUnknown(cfg);
|
||||
|
||||
GPUSet devices = GPUSet::All(param_.gpu_id, param_.n_gpus);
|
||||
GPUSet devices = GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus);
|
||||
ConfigureShards(devices);
|
||||
}
|
||||
|
||||
@@ -527,7 +513,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
});
|
||||
}
|
||||
|
||||
GPUPredictionParam param_;
|
||||
std::unique_ptr<Predictor> cpu_predictor_;
|
||||
std::vector<DeviceShard> shards_;
|
||||
GPUSet devices_;
|
||||
|
||||
Reference in New Issue
Block a user