Remove single_precision_histogram for gpu_hist (#7828)
This commit is contained in:
parent
50d854e02e
commit
90cce38236
@ -59,13 +59,11 @@ Supported parameters
|
|||||||
+--------------------------------+--------------+
|
+--------------------------------+--------------+
|
||||||
| ``interaction_constraints`` | |tick| |
|
| ``interaction_constraints`` | |tick| |
|
||||||
+--------------------------------+--------------+
|
+--------------------------------+--------------+
|
||||||
| ``single_precision_histogram`` | |tick| |
|
| ``single_precision_histogram`` | |cross| |
|
||||||
+--------------------------------+--------------+
|
+--------------------------------+--------------+
|
||||||
|
|
||||||
GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
|
GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
|
||||||
|
|
||||||
The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures.
|
|
||||||
|
|
||||||
The device ordinal (which GPU to use if you have many of them) can be selected using the
|
The device ordinal (which GPU to use if you have many of them) can be selected using the
|
||||||
``gpu_id`` parameter, which defaults to 0 (the first device reported by CUDA runtime).
|
``gpu_id`` parameter, which defaults to 0 (the first device reported by CUDA runtime).
|
||||||
|
|
||||||
|
|||||||
@ -240,7 +240,7 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
|
|||||||
|
|
||||||
* ``single_precision_histogram``, [default= ``false``]
|
* ``single_precision_histogram``, [default= ``false``]
|
||||||
|
|
||||||
- Use single precision to build histograms instead of double precision.
|
- Use single precision to build histograms instead of double precision. Currently disabled for ``gpu_hist``.
|
||||||
|
|
||||||
* ``max_cat_to_onehot``
|
* ``max_cat_to_onehot``
|
||||||
|
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class TreeUpdater : public Configurable {
|
|||||||
GenericParameter const* ctx_ = nullptr;
|
GenericParameter const* ctx_ = nullptr;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {}
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
~TreeUpdater() override = default;
|
~TreeUpdater() override = default;
|
||||||
/*!
|
/*!
|
||||||
@ -98,8 +99,9 @@ class TreeUpdater : public Configurable {
|
|||||||
* \brief Registry entry for tree updater.
|
* \brief Registry entry for tree updater.
|
||||||
*/
|
*/
|
||||||
struct TreeUpdaterReg
|
struct TreeUpdaterReg
|
||||||
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
|
: public dmlc::FunctionRegEntryBase<
|
||||||
std::function<TreeUpdater*(ObjInfo task)> > {};
|
TreeUpdaterReg,
|
||||||
|
std::function<TreeUpdater*(GenericParameter const* tparam, ObjInfo task)> > {};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Macro to register tree updater.
|
* \brief Macro to register tree updater.
|
||||||
|
|||||||
@ -20,8 +20,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const
|
|||||||
if (e == nullptr) {
|
if (e == nullptr) {
|
||||||
LOG(FATAL) << "Unknown tree updater " << name;
|
LOG(FATAL) << "Unknown tree updater " << name;
|
||||||
}
|
}
|
||||||
auto p_updater = (e->body)(task);
|
auto p_updater = (e->body)(tparam, task);
|
||||||
p_updater->ctx_ = tparam;
|
|
||||||
return p_updater;
|
return p_updater;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -268,7 +268,10 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
ObjInfo task_;
|
ObjInfo task_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }
|
explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task)
|
||||||
|
: task_{task}, TreeUpdater(ctx) {
|
||||||
|
monitor_.Init(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
void Configure(const Args &args) override {
|
void Configure(const Args &args) override {
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
@ -365,6 +368,8 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
|
|||||||
.describe(
|
.describe(
|
||||||
"Tree constructor that uses approximate histogram construction "
|
"Tree constructor that uses approximate histogram construction "
|
||||||
"for each node.")
|
"for each node.")
|
||||||
.set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
|
.set_body([](GenericParameter const *ctx, ObjInfo task) {
|
||||||
|
return new GlobalApproxUpdater(ctx, task);
|
||||||
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -33,11 +33,10 @@ namespace tree {
|
|||||||
* \brief base tree maker class that defines common operation
|
* \brief base tree maker class that defines common operation
|
||||||
* needed in tree making
|
* needed in tree making
|
||||||
*/
|
*/
|
||||||
class BaseMaker: public TreeUpdater {
|
class BaseMaker : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
void Configure(const Args& args) override {
|
explicit BaseMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
|
||||||
param_.UpdateAllowUnknown(args);
|
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
|
||||||
}
|
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
|
|||||||
@ -57,7 +57,8 @@ DMLC_REGISTER_PARAMETER(ColMakerTrainParam);
|
|||||||
/*! \brief column-wise update to construct a tree */
|
/*! \brief column-wise update to construct a tree */
|
||||||
class ColMaker: public TreeUpdater {
|
class ColMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
void Configure(const Args& args) override {
|
explicit ColMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
|
||||||
|
void Configure(const Args &args) override {
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
colmaker_param_.UpdateAllowUnknown(args);
|
colmaker_param_.UpdateAllowUnknown(args);
|
||||||
}
|
}
|
||||||
@ -614,8 +615,8 @@ class ColMaker: public TreeUpdater {
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||||
.describe("Grow tree with parallelization over columns.")
|
.describe("Grow tree with parallelization over columns.")
|
||||||
.set_body([](ObjInfo) {
|
.set_body([](GenericParameter const* ctx, ObjInfo) {
|
||||||
return new ColMaker();
|
return new ColMaker(ctx);
|
||||||
});
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -50,12 +50,9 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
|||||||
// training parameters specific to this algorithm
|
// training parameters specific to this algorithm
|
||||||
struct GPUHistMakerTrainParam
|
struct GPUHistMakerTrainParam
|
||||||
: public XGBoostParameter<GPUHistMakerTrainParam> {
|
: public XGBoostParameter<GPUHistMakerTrainParam> {
|
||||||
bool single_precision_histogram;
|
|
||||||
bool debug_synchronize;
|
bool debug_synchronize;
|
||||||
// declare parameters
|
// declare parameters
|
||||||
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
|
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
|
||||||
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
|
|
||||||
"Use single precision to build histograms.");
|
|
||||||
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
|
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
|
||||||
"Check if all distributed tree are identical after tree construction.");
|
"Check if all distributed tree are identical after tree construction.");
|
||||||
}
|
}
|
||||||
@ -557,6 +554,13 @@ struct GPUHistMakerDevice {
|
|||||||
|
|
||||||
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
|
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
|
||||||
RegTree& tree = *p_tree;
|
RegTree& tree = *p_tree;
|
||||||
|
|
||||||
|
// Sanity check - have we created a leaf with no training instances?
|
||||||
|
if (!rabit::IsDistributed() && row_partitioner) {
|
||||||
|
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
|
||||||
|
<< "No training instances in this leaf!";
|
||||||
|
}
|
||||||
|
|
||||||
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
|
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
|
||||||
auto base_weight = candidate.base_weight;
|
auto base_weight = candidate.base_weight;
|
||||||
auto left_weight = candidate.left_weight * param.learning_rate;
|
auto left_weight = candidate.left_weight * param.learning_rate;
|
||||||
@ -702,26 +706,42 @@ struct GPUHistMakerDevice {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename GradientSumT>
|
class GPUHistMaker : public TreeUpdater {
|
||||||
class GPUHistMakerSpecialised {
|
using GradientSumT = GradientPairPrecise;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
|
explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task)
|
||||||
void Configure(const Args& args, GenericParameter const* generic_param) {
|
: TreeUpdater(ctx), task_{task} {};
|
||||||
|
void Configure(const Args& args) override {
|
||||||
|
// Used in test to count how many configurations are performed
|
||||||
|
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
ctx_ = generic_param;
|
|
||||||
hist_maker_param_.UpdateAllowUnknown(args);
|
hist_maker_param_.UpdateAllowUnknown(args);
|
||||||
dh::CheckComputeCapability();
|
dh::CheckComputeCapability();
|
||||||
|
initialised_ = false;
|
||||||
|
|
||||||
monitor_.Init("updater_gpu_hist");
|
monitor_.Init("updater_gpu_hist");
|
||||||
}
|
}
|
||||||
|
|
||||||
~GPUHistMakerSpecialised() { // NOLINT
|
void LoadConfig(Json const& in) override {
|
||||||
|
auto const& config = get<Object const>(in);
|
||||||
|
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||||
|
initialised_ = false;
|
||||||
|
FromJson(config.at("train_param"), ¶m_);
|
||||||
|
}
|
||||||
|
void SaveConfig(Json* p_out) const override {
|
||||||
|
auto& out = *p_out;
|
||||||
|
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
||||||
|
out["train_param"] = ToJson(param_);
|
||||||
|
}
|
||||||
|
|
||||||
|
~GPUHistMaker() { // NOLINT
|
||||||
dh::GlobalMemoryLogger().Log();
|
dh::GlobalMemoryLogger().Log();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree*>& trees) {
|
const std::vector<RegTree*>& trees) override {
|
||||||
monitor_.Start("Update");
|
monitor_.Start("Update");
|
||||||
|
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
@ -791,7 +811,7 @@ class GPUHistMakerSpecialised {
|
|||||||
}
|
}
|
||||||
fs.Seek(0);
|
fs.Seek(0);
|
||||||
rabit::Broadcast(&s_model, 0);
|
rabit::Broadcast(&s_model, 0);
|
||||||
RegTree reference_tree {}; // rank 0 tree
|
RegTree reference_tree{}; // rank 0 tree
|
||||||
reference_tree.Load(&fs);
|
reference_tree.Load(&fs);
|
||||||
CHECK(*local_tree == reference_tree);
|
CHECK(*local_tree == reference_tree);
|
||||||
}
|
}
|
||||||
@ -806,8 +826,8 @@ class GPUHistMakerSpecialised {
|
|||||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
|
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix *data,
|
bool UpdatePredictionCache(const DMatrix* data,
|
||||||
linalg::VectorView<bst_float> p_out_preds) {
|
linalg::VectorView<bst_float> p_out_preds) override {
|
||||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -817,109 +837,33 @@ class GPUHistMakerSpecialised {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
TrainParam param_; // NOLINT
|
TrainParam param_; // NOLINT
|
||||||
MetaInfo* info_{}; // NOLINT
|
MetaInfo* info_{}; // NOLINT
|
||||||
|
|
||||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||||
|
|
||||||
|
char const* Name() const override { return "grow_gpu_hist"; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool initialised_ { false };
|
bool initialised_{false};
|
||||||
|
|
||||||
GPUHistMakerTrainParam hist_maker_param_;
|
GPUHistMakerTrainParam hist_maker_param_;
|
||||||
Context const* ctx_;
|
|
||||||
|
|
||||||
dh::AllReducer reducer_;
|
dh::AllReducer reducer_;
|
||||||
|
|
||||||
DMatrix* p_last_fmat_ { nullptr };
|
DMatrix* p_last_fmat_{nullptr};
|
||||||
RegTree const* p_last_tree_{nullptr};
|
RegTree const* p_last_tree_{nullptr};
|
||||||
ObjInfo task_;
|
ObjInfo task_;
|
||||||
|
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GPUHistMaker : public TreeUpdater {
|
|
||||||
public:
|
|
||||||
explicit GPUHistMaker(ObjInfo task) : task_{task} {}
|
|
||||||
void Configure(const Args& args) override {
|
|
||||||
// Used in test to count how many configurations are performed
|
|
||||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
|
||||||
hist_maker_param_.UpdateAllowUnknown(args);
|
|
||||||
// The passed in args can be empty, if we simply purge the old maker without
|
|
||||||
// preserving parameters then we can't do Update on it.
|
|
||||||
TrainParam param;
|
|
||||||
if (float_maker_) {
|
|
||||||
param = float_maker_->param_;
|
|
||||||
} else if (double_maker_) {
|
|
||||||
param = double_maker_->param_;
|
|
||||||
}
|
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
|
||||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
|
||||||
float_maker_->param_ = param;
|
|
||||||
float_maker_->Configure(args, ctx_);
|
|
||||||
} else {
|
|
||||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
|
||||||
double_maker_->param_ = param;
|
|
||||||
double_maker_->Configure(args, ctx_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
|
||||||
auto const& config = get<Object const>(in);
|
|
||||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
|
||||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
|
||||||
FromJson(config.at("train_param"), &float_maker_->param_);
|
|
||||||
} else {
|
|
||||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
|
||||||
FromJson(config.at("train_param"), &double_maker_->param_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void SaveConfig(Json* p_out) const override {
|
|
||||||
auto& out = *p_out;
|
|
||||||
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
|
||||||
out["train_param"] = ToJson(float_maker_->param_);
|
|
||||||
} else {
|
|
||||||
out["train_param"] = ToJson(double_maker_->param_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
|
||||||
const std::vector<RegTree*>& trees) override {
|
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
|
||||||
float_maker_->Update(gpair, dmat, out_position, trees);
|
|
||||||
} else {
|
|
||||||
double_maker_->Update(gpair, dmat, out_position, trees);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix* data,
|
|
||||||
linalg::VectorView<bst_float> p_out_preds) override {
|
|
||||||
if (hist_maker_param_.single_precision_histogram) {
|
|
||||||
return float_maker_->UpdatePredictionCache(data, p_out_preds);
|
|
||||||
} else {
|
|
||||||
return double_maker_->UpdatePredictionCache(data, p_out_preds);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
char const* Name() const override {
|
|
||||||
return "grow_gpu_hist";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HasNodePosition() const override { return true; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
GPUHistMakerTrainParam hist_maker_param_;
|
|
||||||
ObjInfo task_;
|
|
||||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
|
|
||||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#if !defined(GTEST_TEST)
|
#if !defined(GTEST_TEST)
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||||
.describe("Grow tree with GPU.")
|
.describe("Grow tree with GPU.")
|
||||||
.set_body([](ObjInfo task) { return new GPUHistMaker(task); });
|
.set_body([](GenericParameter const* tparam, ObjInfo task) {
|
||||||
|
return new GPUHistMaker(tparam, task);
|
||||||
|
});
|
||||||
#endif // !defined(GTEST_TEST)
|
#endif // !defined(GTEST_TEST)
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -24,6 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker);
|
|||||||
|
|
||||||
class HistMaker: public BaseMaker {
|
class HistMaker: public BaseMaker {
|
||||||
public:
|
public:
|
||||||
|
explicit HistMaker(GenericParameter const *ctx) : BaseMaker(ctx) {}
|
||||||
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
||||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||||
const std::vector<RegTree *> &trees) override {
|
const std::vector<RegTree *> &trees) override {
|
||||||
@ -262,12 +263,10 @@ class HistMaker: public BaseMaker {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class CQHistMaker: public HistMaker {
|
class CQHistMaker : public HistMaker {
|
||||||
public:
|
public:
|
||||||
CQHistMaker() = default;
|
explicit CQHistMaker(GenericParameter const *ctx) : HistMaker(ctx) {}
|
||||||
char const* Name() const override {
|
char const *Name() const override { return "grow_local_histmaker"; }
|
||||||
return "grow_local_histmaker";
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
struct HistEntry {
|
struct HistEntry {
|
||||||
@ -624,9 +623,7 @@ class CQHistMaker: public HistMaker {
|
|||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||||
.describe("Tree constructor that uses approximate histogram construction.")
|
.describe("Tree constructor that uses approximate histogram construction.")
|
||||||
.set_body([](ObjInfo) {
|
.set_body([](GenericParameter const *ctx, ObjInfo) { return new CQHistMaker(ctx); });
|
||||||
return new CQHistMaker();
|
|
||||||
});
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -21,9 +21,9 @@ namespace tree {
|
|||||||
DMLC_REGISTRY_FILE_TAG(updater_prune);
|
DMLC_REGISTRY_FILE_TAG(updater_prune);
|
||||||
|
|
||||||
/*! \brief pruner that prunes a tree after growing finishes */
|
/*! \brief pruner that prunes a tree after growing finishes */
|
||||||
class TreePruner: public TreeUpdater {
|
class TreePruner : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
explicit TreePruner(ObjInfo task) {
|
explicit TreePruner(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx) {
|
||||||
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
||||||
pruner_monitor_.Init("TreePruner");
|
pruner_monitor_.Init("TreePruner");
|
||||||
}
|
}
|
||||||
@ -112,9 +112,7 @@ class TreePruner: public TreeUpdater {
|
|||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||||
.describe("Pruner that prune the tree according to statistics.")
|
.describe("Pruner that prune the tree according to statistics.")
|
||||||
.set_body([](ObjInfo task) {
|
.set_body([](GenericParameter const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
|
||||||
return new TreePruner(task);
|
|
||||||
});
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -411,6 +411,8 @@ template struct QuantileHistMaker::Builder<double>;
|
|||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
||||||
.describe("Grow tree using quantized histogram.")
|
.describe("Grow tree using quantized histogram.")
|
||||||
.set_body([](ObjInfo task) { return new QuantileHistMaker(task); });
|
.set_body([](GenericParameter const *ctx, ObjInfo task) {
|
||||||
|
return new QuantileHistMaker(ctx, task);
|
||||||
|
});
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -235,7 +235,8 @@ inline BatchParam HistBatch(TrainParam const& param) {
|
|||||||
/*! \brief construct a tree using quantized feature values */
|
/*! \brief construct a tree using quantized feature values */
|
||||||
class QuantileHistMaker: public TreeUpdater {
|
class QuantileHistMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
|
explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task)
|
||||||
|
: task_{task}, TreeUpdater(ctx) {}
|
||||||
void Configure(const Args& args) override;
|
void Configure(const Args& args) override;
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||||
|
|||||||
@ -22,11 +22,10 @@ namespace tree {
|
|||||||
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
||||||
|
|
||||||
/*! \brief pruner that prunes a tree after growing finishs */
|
/*! \brief pruner that prunes a tree after growing finishs */
|
||||||
class TreeRefresher: public TreeUpdater {
|
class TreeRefresher : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
void Configure(const Args& args) override {
|
explicit TreeRefresher(GenericParameter const *ctx) : TreeUpdater(ctx) {}
|
||||||
param_.UpdateAllowUnknown(args);
|
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
|
||||||
}
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
FromJson(config.at("train_param"), &this->param_);
|
FromJson(config.at("train_param"), &this->param_);
|
||||||
@ -160,9 +159,7 @@ class TreeRefresher: public TreeUpdater {
|
|||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||||
.set_body([](ObjInfo) {
|
.set_body([](GenericParameter const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
|
||||||
return new TreeRefresher();
|
|
||||||
});
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -20,8 +20,9 @@ DMLC_REGISTRY_FILE_TAG(updater_sync);
|
|||||||
* \brief syncher that synchronize the tree in all distributed nodes
|
* \brief syncher that synchronize the tree in all distributed nodes
|
||||||
* can implement various strategies, so far it is always set to node 0's tree
|
* can implement various strategies, so far it is always set to node 0's tree
|
||||||
*/
|
*/
|
||||||
class TreeSyncher: public TreeUpdater {
|
class TreeSyncher : public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
|
explicit TreeSyncher(GenericParameter const* tparam) : TreeUpdater(tparam) {}
|
||||||
void Configure(const Args&) override {}
|
void Configure(const Args&) override {}
|
||||||
|
|
||||||
void LoadConfig(Json const&) override {}
|
void LoadConfig(Json const&) override {}
|
||||||
@ -52,9 +53,7 @@ class TreeSyncher: public TreeUpdater {
|
|||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||||
.set_body([](ObjInfo) {
|
.set_body([](GenericParameter const* tparam, ObjInfo) { return new TreeSyncher(tparam); });
|
||||||
return new TreeSyncher();
|
|
||||||
});
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -277,8 +277,10 @@ void TestHistogramIndexImpl() {
|
|||||||
int constexpr kNRows = 1000, kNCols = 10;
|
int constexpr kNRows = 1000, kNCols = 10;
|
||||||
|
|
||||||
// Build 2 matrices and build a histogram maker with that
|
// Build 2 matrices and build a histogram maker with that
|
||||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}},
|
|
||||||
hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
|
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||||
|
tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}},
|
||||||
|
hist_maker_ext{&generic_param,ObjInfo{ObjInfo::kRegression}};
|
||||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
std::unique_ptr<DMatrix> hist_maker_dmat(
|
||||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
||||||
|
|
||||||
@ -291,10 +293,9 @@ void TestHistogramIndexImpl() {
|
|||||||
{"max_leaves", "0"}
|
{"max_leaves", "0"}
|
||||||
};
|
};
|
||||||
|
|
||||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
hist_maker.Configure(training_params);
|
||||||
hist_maker.Configure(training_params, &generic_param);
|
|
||||||
hist_maker.InitDataOnce(hist_maker_dmat.get());
|
hist_maker.InitDataOnce(hist_maker_dmat.get());
|
||||||
hist_maker_ext.Configure(training_params, &generic_param);
|
hist_maker_ext.Configure(training_params);
|
||||||
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
|
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
|
||||||
|
|
||||||
// Extract the device maker from the histogram makers and from that its compressed
|
// Extract the device maker from the histogram makers and from that its compressed
|
||||||
@ -346,9 +347,9 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
|||||||
{"sampling_method", sampling_method},
|
{"sampling_method", sampling_method},
|
||||||
};
|
};
|
||||||
|
|
||||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
|
|
||||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||||
hist_maker.Configure(args, &generic_param);
|
tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}};
|
||||||
|
hist_maker.Configure(args);
|
||||||
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
hist_maker.Update(gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, {tree});
|
hist_maker.Update(gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, {tree});
|
||||||
|
|||||||
@ -16,11 +16,11 @@ class TestGPUBasicModels:
|
|||||||
cpu_test_bm = test_bm.TestModels()
|
cpu_test_bm = test_bm.TestModels()
|
||||||
|
|
||||||
def run_cls(self, X, y):
|
def run_cls(self, X, y):
|
||||||
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
|
cls = xgb.XGBClassifier(tree_method='gpu_hist')
|
||||||
cls.fit(X, y)
|
cls.fit(X, y)
|
||||||
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
|
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
|
||||||
|
|
||||||
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
|
cls = xgb.XGBClassifier(tree_method='gpu_hist')
|
||||||
cls.fit(X, y)
|
cls.fit(X, y)
|
||||||
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
|
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,6 @@ parameter_strategy = strategies.fixed_dictionaries({
|
|||||||
'max_leaves': strategies.integers(0, 256),
|
'max_leaves': strategies.integers(0, 256),
|
||||||
'max_bin': strategies.integers(2, 1024),
|
'max_bin': strategies.integers(2, 1024),
|
||||||
'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
|
'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
|
||||||
'single_precision_histogram': strategies.booleans(),
|
|
||||||
'min_child_weight': strategies.floats(0.5, 2.0),
|
'min_child_weight': strategies.floats(0.5, 2.0),
|
||||||
'seed': strategies.integers(0, 10),
|
'seed': strategies.integers(0, 10),
|
||||||
# We cannot enable subsampling as the training loss can increase
|
# We cannot enable subsampling as the training loss can increase
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user