Remove single_precision_histogram for gpu_hist (#7828)

This commit is contained in:
Rory Mitchell
2022-05-03 14:53:19 +02:00
committed by GitHub
parent 50d854e02e
commit 90cce38236
17 changed files with 97 additions and 155 deletions

View File

@@ -20,8 +20,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const
if (e == nullptr) {
LOG(FATAL) << "Unknown tree updater " << name;
}
auto p_updater = (e->body)(task);
p_updater->ctx_ = tparam;
auto p_updater = (e->body)(tparam, task);
return p_updater;
}

View File

@@ -268,7 +268,10 @@ class GlobalApproxUpdater : public TreeUpdater {
ObjInfo task_;
public:
explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }
explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task)
: task_{task}, TreeUpdater(ctx) {
monitor_.Init(__func__);
}
void Configure(const Args &args) override {
param_.UpdateAllowUnknown(args);
@@ -365,6 +368,8 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe(
"Tree constructor that uses approximate histogram construction "
"for each node.")
.set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
.set_body([](GenericParameter const *ctx, ObjInfo task) {
return new GlobalApproxUpdater(ctx, task);
});
} // namespace tree
} // namespace xgboost

View File

@@ -33,11 +33,10 @@ namespace tree {
* \brief base tree maker class that defines common operations
* needed in tree making
*/
class BaseMaker: public TreeUpdater {
class BaseMaker : public TreeUpdater {
public:
void Configure(const Args& args) override {
param_.UpdateAllowUnknown(args);
}
explicit BaseMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);

View File

@@ -57,7 +57,8 @@ DMLC_REGISTER_PARAMETER(ColMakerTrainParam);
/*! \brief column-wise update to construct a tree */
class ColMaker: public TreeUpdater {
public:
void Configure(const Args& args) override {
explicit ColMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
void Configure(const Args &args) override {
param_.UpdateAllowUnknown(args);
colmaker_param_.UpdateAllowUnknown(args);
}
@@ -614,8 +615,8 @@ class ColMaker: public TreeUpdater {
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
.describe("Grow tree with parallelization over columns.")
.set_body([](ObjInfo) {
return new ColMaker();
.set_body([](GenericParameter const* ctx, ObjInfo) {
return new ColMaker(ctx);
});
} // namespace tree
} // namespace xgboost

View File

@@ -50,12 +50,9 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
// training parameters specific to this algorithm
struct GPUHistMakerTrainParam
: public XGBoostParameter<GPUHistMakerTrainParam> {
bool single_precision_histogram;
bool debug_synchronize;
// declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
"Check if all distributed tree are identical after tree construction.");
}
@@ -557,6 +554,13 @@ struct GPUHistMakerDevice {
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;
// Sanity check - have we created a leaf with no training instances?
if (!rabit::IsDistributed() && row_partitioner) {
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
<< "No training instances in this leaf!";
}
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
auto base_weight = candidate.base_weight;
auto left_weight = candidate.left_weight * param.learning_rate;
@@ -702,26 +706,42 @@ struct GPUHistMakerDevice {
}
};
template <typename GradientSumT>
class GPUHistMakerSpecialised {
class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise;
public:
explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
void Configure(const Args& args, GenericParameter const* generic_param) {
explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task)
: TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
param_.UpdateAllowUnknown(args);
ctx_ = generic_param;
hist_maker_param_.UpdateAllowUnknown(args);
dh::CheckComputeCapability();
initialised_ = false;
monitor_.Init("updater_gpu_hist");
}
~GPUHistMakerSpecialised() { // NOLINT
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
initialised_ = false;
FromJson(config.at("train_param"), &param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
out["train_param"] = ToJson(param_);
}
~GPUHistMaker() { // NOLINT
dh::GlobalMemoryLogger().Log();
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) {
const std::vector<RegTree*>& trees) override {
monitor_.Start("Update");
// rescale learning rate according to size of trees
@@ -791,7 +811,7 @@ class GPUHistMakerSpecialised {
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
RegTree reference_tree {}; // rank 0 tree
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
}
@@ -806,8 +826,8 @@ class GPUHistMakerSpecialised {
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
}
bool UpdatePredictionCache(const DMatrix *data,
linalg::VectorView<bst_float> p_out_preds) {
bool UpdatePredictionCache(const DMatrix* data,
linalg::VectorView<bst_float> p_out_preds) override {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
@@ -817,109 +837,33 @@ class GPUHistMakerSpecialised {
return true;
}
TrainParam param_; // NOLINT
MetaInfo* info_{}; // NOLINT
TrainParam param_; // NOLINT
MetaInfo* info_{}; // NOLINT
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
char const* Name() const override { return "grow_gpu_hist"; }
private:
bool initialised_ { false };
bool initialised_{false};
GPUHistMakerTrainParam hist_maker_param_;
Context const* ctx_;
dh::AllReducer reducer_;
DMatrix* p_last_fmat_ { nullptr };
DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
ObjInfo task_;
common::Monitor monitor_;
};
/*!
 * \brief TreeUpdater front-end that forwards every call to one of two
 *  GPUHistMakerSpecialised instances.
 *
 * hist_maker_param_.single_precision_histogram selects the target:
 * float_maker_ (GradientPair) when true, double_maker_ (GradientPairPrecise)
 * otherwise.
 */
class GPUHistMaker : public TreeUpdater {
public:
// Only stores the task; neither maker is created until Configure or LoadConfig.
explicit GPUHistMaker(ObjInfo task) : task_{task} {}
/*!
 * \brief Re-read the GPU hist parameters and rebuild the selected maker.
 * \param args Key/value arguments, applied to hist_maker_param_ and passed on
 *  to the selected maker's Configure.
 */
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
hist_maker_param_.UpdateAllowUnknown(args);
// The passed in args can be empty, if we simply purge the old maker without
// preserving parameters then we can't do Update on it.
TrainParam param;
// Carry over the TrainParam of an existing maker; float_maker_ is checked
// first, and param stays default-constructed when neither maker exists.
if (float_maker_) {
param = float_maker_->param_;
} else if (double_maker_) {
param = double_maker_->param_;
}
// Replace only the maker chosen by the flag; the other one is left as is.
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
float_maker_->param_ = param;
float_maker_->Configure(args, ctx_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
double_maker_->param_ = param;
double_maker_->Configure(args, ctx_);
}
}
/*!
 * \brief Restore parameters from JSON.
 *
 * Reads "gpu_hist_train_param" into hist_maker_param_, creates the maker the
 * restored flag selects, then reads "train_param" into that maker's param_.
 */
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
FromJson(config.at("train_param"), &float_maker_->param_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
FromJson(config.at("train_param"), &double_maker_->param_);
}
}
/*!
 * \brief Write "gpu_hist_train_param" and the selected maker's "train_param"
 *  into *p_out.
 *
 * NOTE(review): the selected maker is dereferenced without a null check, so
 * this presumably requires Configure or LoadConfig to have run first - confirm.
 */
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
out["train_param"] = ToJson(float_maker_->param_);
} else {
out["train_param"] = ToJson(double_maker_->param_);
}
}
// Forwards all arguments unchanged to the selected maker's Update.
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override {
if (hist_maker_param_.single_precision_histogram) {
float_maker_->Update(gpair, dmat, out_position, trees);
} else {
double_maker_->Update(gpair, dmat, out_position, trees);
}
}
// Forwards to the selected maker's UpdatePredictionCache and returns its result.
bool UpdatePredictionCache(const DMatrix* data,
linalg::VectorView<bst_float> p_out_preds) override {
if (hist_maker_param_.single_precision_histogram) {
return float_maker_->UpdatePredictionCache(data, p_out_preds);
} else {
return double_maker_->UpdatePredictionCache(data, p_out_preds);
}
}
// Name reported by this updater.
char const* Name() const override {
return "grow_gpu_hist";
}
// Always reports true.
bool HasNodePosition() const override { return true; }
private:
// Parameters specific to this updater; holds the precision selector flag.
GPUHistMakerTrainParam hist_maker_param_;
// Task info handed to every maker this wrapper constructs.
ObjInfo task_;
// Single-precision maker; null until selected by Configure or LoadConfig.
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
// Double-precision maker; null until selected by Configure or LoadConfig.
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
};
#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([](ObjInfo task) { return new GPUHistMaker(task); });
.set_body([](GenericParameter const* tparam, ObjInfo task) {
return new GPUHistMaker(tparam, task);
});
#endif // !defined(GTEST_TEST)
} // namespace tree

View File

@@ -24,6 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker);
class HistMaker: public BaseMaker {
public:
explicit HistMaker(GenericParameter const *ctx) : BaseMaker(ctx) {}
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
@@ -262,12 +263,10 @@ class HistMaker: public BaseMaker {
}
};
class CQHistMaker: public HistMaker {
class CQHistMaker : public HistMaker {
public:
CQHistMaker() = default;
char const* Name() const override {
return "grow_local_histmaker";
}
explicit CQHistMaker(GenericParameter const *ctx) : HistMaker(ctx) {}
char const *Name() const override { return "grow_local_histmaker"; }
protected:
struct HistEntry {
@@ -624,9 +623,7 @@ class CQHistMaker: public HistMaker {
};
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([](ObjInfo) {
return new CQHistMaker();
});
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([](GenericParameter const *ctx, ObjInfo) { return new CQHistMaker(ctx); });
} // namespace tree
} // namespace xgboost

View File

@@ -21,9 +21,9 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_prune);
/*! \brief pruner that prunes a tree after growing finishes */
class TreePruner: public TreeUpdater {
class TreePruner : public TreeUpdater {
public:
explicit TreePruner(ObjInfo task) {
explicit TreePruner(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx) {
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
pruner_monitor_.Init("TreePruner");
}
@@ -112,9 +112,7 @@ class TreePruner: public TreeUpdater {
};
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
.describe("Pruner that prune the tree according to statistics.")
.set_body([](ObjInfo task) {
return new TreePruner(task);
});
.describe("Pruner that prune the tree according to statistics.")
.set_body([](GenericParameter const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
} // namespace tree
} // namespace xgboost

View File

@@ -411,6 +411,8 @@ template struct QuantileHistMaker::Builder<double>;
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([](ObjInfo task) { return new QuantileHistMaker(task); });
.set_body([](GenericParameter const *ctx, ObjInfo task) {
return new QuantileHistMaker(ctx, task);
});
} // namespace tree
} // namespace xgboost

View File

@@ -235,7 +235,8 @@ inline BatchParam HistBatch(TrainParam const& param) {
/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker: public TreeUpdater {
public:
explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task)
: task_{task}, TreeUpdater(ctx) {}
void Configure(const Args& args) override;
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,

View File

@@ -22,11 +22,10 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_refresh);
/*! \brief refresher that refreshes the weight and statistics of a tree according to data */
class TreeRefresher: public TreeUpdater {
class TreeRefresher : public TreeUpdater {
public:
void Configure(const Args& args) override {
param_.UpdateAllowUnknown(args);
}
explicit TreeRefresher(GenericParameter const *ctx) : TreeUpdater(ctx) {}
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_);
@@ -160,9 +159,7 @@ class TreeRefresher: public TreeUpdater {
};
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
.describe("Refresher that refreshes the weight and statistics according to data.")
.set_body([](ObjInfo) {
return new TreeRefresher();
});
.describe("Refresher that refreshes the weight and statistics according to data.")
.set_body([](GenericParameter const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
} // namespace tree
} // namespace xgboost

View File

@@ -20,8 +20,9 @@ DMLC_REGISTRY_FILE_TAG(updater_sync);
* \brief syncher that synchronizes the tree in all distributed nodes
* can implement various strategies, so far it is always set to node 0's tree
*/
class TreeSyncher: public TreeUpdater {
class TreeSyncher : public TreeUpdater {
public:
explicit TreeSyncher(GenericParameter const* tparam) : TreeUpdater(tparam) {}
void Configure(const Args&) override {}
void LoadConfig(Json const&) override {}
@@ -52,9 +53,7 @@ class TreeSyncher: public TreeUpdater {
};
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
.describe("Syncher that synchronize the tree in all distributed nodes.")
.set_body([](ObjInfo) {
return new TreeSyncher();
});
.describe("Syncher that synchronize the tree in all distributed nodes.")
.set_body([](GenericParameter const* tparam, ObjInfo) { return new TreeSyncher(tparam); });
} // namespace tree
} // namespace xgboost