Remove single_precision_histogram for gpu_hist (#7828)

Rory Mitchell authored on 2022-05-03 14:53:19 +02:00, committed by GitHub
parent 50d854e02e
commit 90cce38236
17 changed files with 97 additions and 155 deletions

doc/gpu/index.rst

@@ -59,13 +59,11 @@ Supported parameters
 +--------------------------------+--------------+
 | ``interaction_constraints``    | |tick|       |
 +--------------------------------+--------------+
-| ``single_precision_histogram`` | |tick|       |
+| ``single_precision_histogram`` | |cross|      |
 +--------------------------------+--------------+
 
 GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
 
-The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures.
-
 The device ordinal (which GPU to use if you have many of them) can be selected using the
 ``gpu_id`` parameter, which defaults to 0 (the first device reported by CUDA runtime).
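
To make the documented workflow concrete, here is a minimal sketch using the Python package; the dataset and parameter values are illustrative, not part of this change:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(100, 10), np.random.randint(2, size=100)
dtrain = xgb.DMatrix(X, label=y)

params = {
    "objective": "binary:logistic",
    "tree_method": "gpu_hist",     # GPU-accelerated histogram tree method
    "gpu_id": 0,                   # device ordinal; 0 is the first device reported by CUDA
    "predictor": "cpu_predictor",  # switch prediction to CPU to conserve GPU memory
}
booster = xgb.train(params, dtrain, num_boost_round=10)
preds = booster.predict(dtrain)    # runs on the CPU predictor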

doc/parameter.rst

@@ -240,7 +240,7 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
 * ``single_precision_histogram``, [default= ``false``]
 
-  - Use single precision to build histograms instead of double precision.
+  - Use single precision to build histograms instead of double precision. Currently disabled for ``gpu_hist``.
 
 * ``max_cat_to_onehot``
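
A short sketch of what the amended text implies: the flag remains meaningful for the CPU ``hist`` method but no longer has an effect under ``gpu_hist``; values below are illustrative only.

import numpy as np
import xgboost as xgb

X, y = np.random.rand(100, 10), np.random.rand(100)
dtrain = xgb.DMatrix(X, label=y)

params = {
    "tree_method": "hist",               # CPU histogram method
    "single_precision_histogram": True,  # default false; now disabled for gpu_hist
}
booster = xgb.train(params, dtrain, num_boost_round=10)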

include/xgboost/tree_updater.h

@@ -35,6 +35,7 @@ class TreeUpdater : public Configurable {
   GenericParameter const* ctx_ = nullptr;
 
  public:
+  explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {}
   /*! \brief virtual destructor */
   ~TreeUpdater() override = default;
   /*!
@@ -98,8 +99,9 @@ class TreeUpdater : public Configurable {
  * \brief Registry entry for tree updater.
  */
 struct TreeUpdaterReg
-    : public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
-                                        std::function<TreeUpdater*(ObjInfo task)> > {};
+    : public dmlc::FunctionRegEntryBase<
+          TreeUpdaterReg,
+          std::function<TreeUpdater*(GenericParameter const* tparam, ObjInfo task)> > {};
 
 /*!
  * \brief Macro to register tree updater.

src/tree/tree_updater.cc

@@ -20,8 +20,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter cons
   if (e == nullptr) {
     LOG(FATAL) << "Unknown tree updater " << name;
   }
-  auto p_updater = (e->body)(task);
-  p_updater->ctx_ = tparam;
+  auto p_updater = (e->body)(tparam, task);
   return p_updater;
 }

src/tree/updater_approx.cc

@@ -268,7 +268,10 @@ class GlobalApproxUpdater : public TreeUpdater {
   ObjInfo task_;
 
  public:
-  explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }
+  explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task)
+      : task_{task}, TreeUpdater(ctx) {
+    monitor_.Init(__func__);
+  }
 
   void Configure(const Args &args) override {
     param_.UpdateAllowUnknown(args);
@@ -365,6 +368,8 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
     .describe(
         "Tree constructor that uses approximate histogram construction "
         "for each node.")
-    .set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
+    .set_body([](GenericParameter const *ctx, ObjInfo task) {
+      return new GlobalApproxUpdater(ctx, task);
+    });
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_basemaker-inl.h

@@ -35,9 +35,8 @@ namespace tree {
  */
 class BaseMaker : public TreeUpdater {
  public:
-  void Configure(const Args& args) override {
-    param_.UpdateAllowUnknown(args);
-  }
+  explicit BaseMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
+  void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
 
   void LoadConfig(Json const& in) override {
     auto const& config = get<Object const>(in);

src/tree/updater_colmaker.cc

@@ -57,6 +57,7 @@ DMLC_REGISTER_PARAMETER(ColMakerTrainParam);
 /*! \brief column-wise update to construct a tree */
 class ColMaker: public TreeUpdater {
  public:
+  explicit ColMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
   void Configure(const Args &args) override {
     param_.UpdateAllowUnknown(args);
     colmaker_param_.UpdateAllowUnknown(args);
@@ -614,8 +615,8 @@ class ColMaker: public TreeUpdater {
 XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
     .describe("Grow tree with parallelization over columns.")
-    .set_body([](ObjInfo) {
-      return new ColMaker();
+    .set_body([](GenericParameter const* ctx, ObjInfo) {
+      return new ColMaker(ctx);
     });
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_gpu_hist.cu

@@ -50,12 +50,9 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 // training parameters specific to this algorithm
 struct GPUHistMakerTrainParam
     : public XGBoostParameter<GPUHistMakerTrainParam> {
-  bool single_precision_histogram;
   bool debug_synchronize;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
-    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
-        "Use single precision to build histograms.");
     DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
         "Check if all distributed tree are identical after tree construction.");
   }
@@ -557,6 +554,13 @@ struct GPUHistMakerDevice {
   void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
+
+    // Sanity check - have we created a leaf with no training instances?
+    if (!rabit::IsDistributed() && row_partitioner) {
+      CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
+          << "No training instances in this leaf!";
+    }
+
     auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
     auto base_weight = candidate.base_weight;
     auto left_weight = candidate.left_weight * param.learning_rate;
@@ -702,26 +706,42 @@
   }
 };
 
-template <typename GradientSumT>
-class GPUHistMakerSpecialised {
+class GPUHistMaker : public TreeUpdater {
+  using GradientSumT = GradientPairPrecise;
+
  public:
-  explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
-  void Configure(const Args& args, GenericParameter const* generic_param) {
+  explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task)
+      : TreeUpdater(ctx), task_{task} {};
+  void Configure(const Args& args) override {
+    // Used in test to count how many configurations are performed
+    LOG(DEBUG) << "[GPU Hist]: Configure";
     param_.UpdateAllowUnknown(args);
-    ctx_ = generic_param;
     hist_maker_param_.UpdateAllowUnknown(args);
     dh::CheckComputeCapability();
+    initialised_ = false;
+
     monitor_.Init("updater_gpu_hist");
   }
 
-  ~GPUHistMakerSpecialised() {  // NOLINT
+  void LoadConfig(Json const& in) override {
+    auto const& config = get<Object const>(in);
+    FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
+    initialised_ = false;
+    FromJson(config.at("train_param"), &param_);
+  }
+  void SaveConfig(Json* p_out) const override {
+    auto& out = *p_out;
+    out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
+    out["train_param"] = ToJson(param_);
+  }
+
+  ~GPUHistMaker() {  // NOLINT
     dh::GlobalMemoryLogger().Log();
   }
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               common::Span<HostDeviceVector<bst_node_t>> out_position,
-              const std::vector<RegTree*>& trees) {
+              const std::vector<RegTree*>& trees) override {
     monitor_.Start("Update");
 
     // rescale learning rate according to size of trees
@@ -807,7 +827,7 @@ class GPUHistMakerSpecialised {
   }
 
   bool UpdatePredictionCache(const DMatrix* data,
-                             linalg::VectorView<bst_float> p_out_preds) {
+                             linalg::VectorView<bst_float> p_out_preds) override {
     if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
       return false;
     }
@@ -822,11 +842,12 @@ class GPUHistMakerSpecialised {
   std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker;  // NOLINT
 
+  char const* Name() const override { return "grow_gpu_hist"; }
+
  private:
   bool initialised_{false};
 
   GPUHistMakerTrainParam hist_maker_param_;
-  Context const* ctx_;
   dh::AllReducer reducer_;
@@ -837,89 +858,12 @@ class GPUHistMakerSpecialised {
   common::Monitor monitor_;
 };
 
-class GPUHistMaker : public TreeUpdater {
- public:
-  explicit GPUHistMaker(ObjInfo task) : task_{task} {}
-  void Configure(const Args& args) override {
-    // Used in test to count how many configurations are performed
-    LOG(DEBUG) << "[GPU Hist]: Configure";
-    hist_maker_param_.UpdateAllowUnknown(args);
-    // The passed in args can be empty, if we simply purge the old maker without
-    // preserving parameters then we can't do Update on it.
-    TrainParam param;
-    if (float_maker_) {
-      param = float_maker_->param_;
-    } else if (double_maker_) {
-      param = double_maker_->param_;
-    }
-    if (hist_maker_param_.single_precision_histogram) {
-      float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
-      float_maker_->param_ = param;
-      float_maker_->Configure(args, ctx_);
-    } else {
-      double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
-      double_maker_->param_ = param;
-      double_maker_->Configure(args, ctx_);
-    }
-  }
-
-  void LoadConfig(Json const& in) override {
-    auto const& config = get<Object const>(in);
-    FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
-    if (hist_maker_param_.single_precision_histogram) {
-      float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
-      FromJson(config.at("train_param"), &float_maker_->param_);
-    } else {
-      double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
-      FromJson(config.at("train_param"), &double_maker_->param_);
-    }
-  }
-  void SaveConfig(Json* p_out) const override {
-    auto& out = *p_out;
-    out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
-    if (hist_maker_param_.single_precision_histogram) {
-      out["train_param"] = ToJson(float_maker_->param_);
-    } else {
-      out["train_param"] = ToJson(double_maker_->param_);
-    }
-  }
-
-  void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
-              common::Span<HostDeviceVector<bst_node_t>> out_position,
-              const std::vector<RegTree*>& trees) override {
-    if (hist_maker_param_.single_precision_histogram) {
-      float_maker_->Update(gpair, dmat, out_position, trees);
-    } else {
-      double_maker_->Update(gpair, dmat, out_position, trees);
-    }
-  }
-
-  bool UpdatePredictionCache(const DMatrix* data,
-                             linalg::VectorView<bst_float> p_out_preds) override {
-    if (hist_maker_param_.single_precision_histogram) {
-      return float_maker_->UpdatePredictionCache(data, p_out_preds);
-    } else {
-      return double_maker_->UpdatePredictionCache(data, p_out_preds);
-    }
-  }
-
-  char const* Name() const override {
-    return "grow_gpu_hist";
-  }
-
-  bool HasNodePosition() const override { return true; }
-
- private:
-  GPUHistMakerTrainParam hist_maker_param_;
-  ObjInfo task_;
-  std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
-  std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
-};
-
 #if !defined(GTEST_TEST)
 XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
     .describe("Grow tree with GPU.")
-    .set_body([](ObjInfo task) { return new GPUHistMaker(task); });
+    .set_body([](GenericParameter const* tparam, ObjInfo task) {
+      return new GPUHistMaker(tparam, task);
+    });
 #endif  // !defined(GTEST_TEST)
 }  // namespace tree
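
Since the merged updater now serializes a single ``train_param`` in ``LoadConfig``/``SaveConfig``, a booster configuration round-trips without the precision branch. A hedged sketch from the Python side, illustrative only:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.rand(64)
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=2)

config = booster.save_config()  # JSON string holding learner and updater parameters
booster.load_config(config)     # exercises the LoadConfig path shown above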

src/tree/updater_histmaker.cc

@@ -24,6 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker);
 class HistMaker: public BaseMaker {
  public:
+  explicit HistMaker(GenericParameter const *ctx) : BaseMaker(ctx) {}
   void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree *> &trees) override {
@@ -264,10 +265,8 @@ class HistMaker: public BaseMaker {
 class CQHistMaker : public HistMaker {
  public:
-  CQHistMaker() = default;
-  char const* Name() const override {
-    return "grow_local_histmaker";
-  }
+  explicit CQHistMaker(GenericParameter const *ctx) : HistMaker(ctx) {}
+  char const *Name() const override { return "grow_local_histmaker"; }
 
  protected:
   struct HistEntry {
@@ -625,8 +624,6 @@ class CQHistMaker: public HistMaker {
 XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
     .describe("Tree constructor that uses approximate histogram construction.")
-    .set_body([](ObjInfo) {
-      return new CQHistMaker();
-    });
+    .set_body([](GenericParameter const *ctx, ObjInfo) { return new CQHistMaker(ctx); });
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_prune.cc

@@ -23,7 +23,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
 /*! \brief pruner that prunes a tree after growing finishes */
 class TreePruner : public TreeUpdater {
  public:
-  explicit TreePruner(ObjInfo task) {
+  explicit TreePruner(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx) {
     syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
     pruner_monitor_.Init("TreePruner");
   }
@@ -113,8 +113,6 @@ class TreePruner: public TreeUpdater {
 XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
     .describe("Pruner that prune the tree according to statistics.")
-    .set_body([](ObjInfo task) {
-      return new TreePruner(task);
-    });
+    .set_body([](GenericParameter const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_quantile_hist.cc

@@ -411,6 +411,8 @@ template struct QuantileHistMaker::Builder<double>;
 XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
     .describe("Grow tree using quantized histogram.")
-    .set_body([](ObjInfo task) { return new QuantileHistMaker(task); });
+    .set_body([](GenericParameter const *ctx, ObjInfo task) {
+      return new QuantileHistMaker(ctx, task);
+    });
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_quantile_hist.h

@@ -235,7 +235,8 @@ inline BatchParam HistBatch(TrainParam const& param) {
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
-  explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
+  explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task)
+      : task_{task}, TreeUpdater(ctx) {}
   void Configure(const Args& args) override;
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,

src/tree/updater_refresh.cc

@@ -24,9 +24,8 @@ DMLC_REGISTRY_FILE_TAG(updater_refresh);
 /*! \brief pruner that prunes a tree after growing finishs */
 class TreeRefresher : public TreeUpdater {
  public:
-  void Configure(const Args& args) override {
-    param_.UpdateAllowUnknown(args);
-  }
+  explicit TreeRefresher(GenericParameter const *ctx) : TreeUpdater(ctx) {}
+  void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
   void LoadConfig(Json const& in) override {
     auto const& config = get<Object const>(in);
     FromJson(config.at("train_param"), &this->param_);
@@ -161,8 +160,6 @@ class TreeRefresher: public TreeUpdater {
 XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
     .describe("Refresher that refreshes the weight and statistics according to data.")
-    .set_body([](ObjInfo) {
-      return new TreeRefresher();
-    });
+    .set_body([](GenericParameter const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
 }  // namespace tree
 }  // namespace xgboost

src/tree/updater_sync.cc

@@ -22,6 +22,7 @@ DMLC_REGISTRY_FILE_TAG(updater_sync);
  */
 class TreeSyncher : public TreeUpdater {
  public:
+  explicit TreeSyncher(GenericParameter const* tparam) : TreeUpdater(tparam) {}
   void Configure(const Args&) override {}
 
   void LoadConfig(Json const&) override {}
@@ -53,8 +54,6 @@ class TreeSyncher: public TreeUpdater {
 XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
     .describe("Syncher that synchronize the tree in all distributed nodes.")
-    .set_body([](ObjInfo) {
-      return new TreeSyncher();
-    });
+    .set_body([](GenericParameter const* tparam, ObjInfo) { return new TreeSyncher(tparam); });
 }  // namespace tree
 }  // namespace xgboost

tests/cpp/tree/test_gpu_hist.cu

@@ -277,8 +277,10 @@ void TestHistogramIndexImpl() {
   int constexpr kNRows = 1000, kNCols = 10;
 
   // Build 2 matrices and build a histogram maker with that
-  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}},
-      hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
+  GenericParameter generic_param(CreateEmptyGenericParam(0));
+  tree::GPUHistMaker hist_maker{&generic_param, ObjInfo{ObjInfo::kRegression}},
+      hist_maker_ext{&generic_param, ObjInfo{ObjInfo::kRegression}};
 
   std::unique_ptr<DMatrix> hist_maker_dmat(
       CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
@@ -291,10 +293,9 @@ void TestHistogramIndexImpl() {
       {"max_leaves", "0"}
   };
 
-  GenericParameter generic_param(CreateEmptyGenericParam(0));
-  hist_maker.Configure(training_params, &generic_param);
+  hist_maker.Configure(training_params);
   hist_maker.InitDataOnce(hist_maker_dmat.get());
-  hist_maker_ext.Configure(training_params, &generic_param);
+  hist_maker_ext.Configure(training_params);
   hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
 
   // Extract the device maker from the histogram makers and from that its compressed
@@ -346,9 +347,9 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
       {"sampling_method", sampling_method},
   };
 
-  tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
   GenericParameter generic_param(CreateEmptyGenericParam(0));
-  hist_maker.Configure(args, &generic_param);
+  tree::GPUHistMaker hist_maker{&generic_param, ObjInfo{ObjInfo::kRegression}};
+  hist_maker.Configure(args);
 
   std::vector<HostDeviceVector<bst_node_t>> position(1);
   hist_maker.Update(gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, {tree});

tests/python-gpu/test_gpu_basic_models.py

@@ -16,11 +16,11 @@ class TestGPUBasicModels:
     cpu_test_bm = test_bm.TestModels()
 
     def run_cls(self, X, y):
-        cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
+        cls = xgb.XGBClassifier(tree_method='gpu_hist')
         cls.fit(X, y)
         cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
 
-        cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
+        cls = xgb.XGBClassifier(tree_method='gpu_hist')
         cls.fit(X, y)
         cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')

tests/python-gpu/test_gpu_updaters.py

@@ -15,7 +15,6 @@ parameter_strategy = strategies.fixed_dictionaries({
     'max_leaves': strategies.integers(0, 256),
     'max_bin': strategies.integers(2, 1024),
     'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
-    'single_precision_histogram': strategies.booleans(),
     'min_child_weight': strategies.floats(0.5, 2.0),
     'seed': strategies.integers(0, 10),
     # We cannot enable subsampling as the training loss can increase
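
For context, a hedged sketch of how a strategy like the one above typically drives a property-based training test; ``test_gpu_hist_params`` and its assertion are illustrative, not the repository's exact test:

from hypothesis import given, settings, strategies
import numpy as np
import xgboost as xgb

param_strategy = strategies.fixed_dictionaries({
    'max_depth': strategies.integers(1, 11),
    'max_bin': strategies.integers(2, 256),
    'min_child_weight': strategies.floats(0.5, 2.0),
})

@given(param_strategy)
@settings(deadline=None, max_examples=10)
def test_gpu_hist_params(params):
    rng = np.random.RandomState(0)   # fixed data; Hypothesis varies only the parameters
    X, y = rng.rand(64, 4), rng.rand(64)
    dtrain = xgb.DMatrix(X, label=y)
    params = dict(params, tree_method='gpu_hist')
    evals_result = {}
    xgb.train(params, dtrain, num_boost_round=8,
              evals=[(dtrain, 'train')], evals_result=evals_result)
    # training should complete and produce finite evaluation metrics
    assert np.all(np.isfinite(evals_result['train']['rmse']))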