Remove single_precision_histogram for gpu_hist (#7828)
This commit is contained in:
parent
50d854e02e
commit
90cce38236
@ -59,13 +59,11 @@ Supported parameters
|
||||
+--------------------------------+--------------+
|
||||
| ``interaction_constraints`` | |tick| |
|
||||
+--------------------------------+--------------+
|
||||
| ``single_precision_histogram`` | |tick| |
|
||||
| ``single_precision_histogram`` | |cross| |
|
||||
+--------------------------------+--------------+
|
||||
|
||||
GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``.
|
||||
|
||||
The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures.
|
||||
|
||||
The device ordinal (which GPU to use if you have many of them) can be selected using the
|
||||
``gpu_id`` parameter, which defaults to 0 (the first device reported by CUDA runtime).
|
||||
|
||||
|
||||
@ -240,7 +240,7 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
|
||||
|
||||
* ``single_precision_histogram``, [default= ``false``]
|
||||
|
||||
- Use single precision to build histograms instead of double precision.
|
||||
- Use single precision to build histograms instead of double precision. Currently disabled for ``gpu_hist``.
|
||||
|
||||
* ``max_cat_to_onehot``
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ class TreeUpdater : public Configurable {
|
||||
GenericParameter const* ctx_ = nullptr;
|
||||
|
||||
public:
|
||||
explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {}
|
||||
/*! \brief virtual destructor */
|
||||
~TreeUpdater() override = default;
|
||||
/*!
|
||||
@ -98,8 +99,9 @@ class TreeUpdater : public Configurable {
|
||||
* \brief Registry entry for tree updater.
|
||||
*/
|
||||
struct TreeUpdaterReg
|
||||
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
|
||||
std::function<TreeUpdater*(ObjInfo task)> > {};
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
TreeUpdaterReg,
|
||||
std::function<TreeUpdater*(GenericParameter const* tparam, ObjInfo task)> > {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register tree updater.
|
||||
|
||||
@ -20,8 +20,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown tree updater " << name;
|
||||
}
|
||||
auto p_updater = (e->body)(task);
|
||||
p_updater->ctx_ = tparam;
|
||||
auto p_updater = (e->body)(tparam, task);
|
||||
return p_updater;
|
||||
}
|
||||
|
||||
|
||||
@ -268,7 +268,10 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
ObjInfo task_;
|
||||
|
||||
public:
|
||||
explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); }
|
||||
explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task)
|
||||
: task_{task}, TreeUpdater(ctx) {
|
||||
monitor_.Init(__func__);
|
||||
}
|
||||
|
||||
void Configure(const Args &args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
@ -365,6 +368,8 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
|
||||
.describe(
|
||||
"Tree constructor that uses approximate histogram construction "
|
||||
"for each node.")
|
||||
.set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); });
|
||||
.set_body([](GenericParameter const *ctx, ObjInfo task) {
|
||||
return new GlobalApproxUpdater(ctx, task);
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -33,11 +33,10 @@ namespace tree {
|
||||
* \brief base tree maker class that defines common operation
|
||||
* needed in tree making
|
||||
*/
|
||||
class BaseMaker: public TreeUpdater {
|
||||
class BaseMaker : public TreeUpdater {
|
||||
public:
|
||||
void Configure(const Args& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
explicit BaseMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
|
||||
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
|
||||
@ -57,7 +57,8 @@ DMLC_REGISTER_PARAMETER(ColMakerTrainParam);
|
||||
/*! \brief column-wise update to construct a tree */
|
||||
class ColMaker: public TreeUpdater {
|
||||
public:
|
||||
void Configure(const Args& args) override {
|
||||
explicit ColMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {}
|
||||
void Configure(const Args &args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
colmaker_param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
@ -614,8 +615,8 @@ class ColMaker: public TreeUpdater {
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||
.describe("Grow tree with parallelization over columns.")
|
||||
.set_body([](ObjInfo) {
|
||||
return new ColMaker();
|
||||
.set_body([](GenericParameter const* ctx, ObjInfo) {
|
||||
return new ColMaker(ctx);
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -50,12 +50,9 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
// training parameters specific to this algorithm
|
||||
struct GPUHistMakerTrainParam
|
||||
: public XGBoostParameter<GPUHistMakerTrainParam> {
|
||||
bool single_precision_histogram;
|
||||
bool debug_synchronize;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
|
||||
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
|
||||
"Use single precision to build histograms.");
|
||||
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
|
||||
"Check if all distributed tree are identical after tree construction.");
|
||||
}
|
||||
@ -557,6 +554,13 @@ struct GPUHistMakerDevice {
|
||||
|
||||
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
// Sanity check - have we created a leaf with no training instances?
|
||||
if (!rabit::IsDistributed() && row_partitioner) {
|
||||
CHECK(row_partitioner->GetRows(candidate.nid).size() > 0)
|
||||
<< "No training instances in this leaf!";
|
||||
}
|
||||
|
||||
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
|
||||
auto base_weight = candidate.base_weight;
|
||||
auto left_weight = candidate.left_weight * param.learning_rate;
|
||||
@ -702,26 +706,42 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
class GPUHistMakerSpecialised {
|
||||
class GPUHistMaker : public TreeUpdater {
|
||||
using GradientSumT = GradientPairPrecise;
|
||||
|
||||
public:
|
||||
explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {};
|
||||
void Configure(const Args& args, GenericParameter const* generic_param) {
|
||||
explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task)
|
||||
: TreeUpdater(ctx), task_{task} {};
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||
param_.UpdateAllowUnknown(args);
|
||||
ctx_ = generic_param;
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
dh::CheckComputeCapability();
|
||||
initialised_ = false;
|
||||
|
||||
monitor_.Init("updater_gpu_hist");
|
||||
}
|
||||
|
||||
~GPUHistMakerSpecialised() { // NOLINT
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
initialised_ = false;
|
||||
FromJson(config.at("train_param"), ¶m_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
||||
out["train_param"] = ToJson(param_);
|
||||
}
|
||||
|
||||
~GPUHistMaker() { // NOLINT
|
||||
dh::GlobalMemoryLogger().Log();
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor_.Start("Update");
|
||||
|
||||
// rescale learning rate according to size of trees
|
||||
@ -791,7 +811,7 @@ class GPUHistMakerSpecialised {
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
RegTree reference_tree {}; // rank 0 tree
|
||||
RegTree reference_tree{}; // rank 0 tree
|
||||
reference_tree.Load(&fs);
|
||||
CHECK(*local_tree == reference_tree);
|
||||
}
|
||||
@ -806,8 +826,8 @@ class GPUHistMakerSpecialised {
|
||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix *data,
|
||||
linalg::VectorView<bst_float> p_out_preds) {
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::VectorView<bst_float> p_out_preds) override {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
@ -817,109 +837,33 @@ class GPUHistMakerSpecialised {
|
||||
return true;
|
||||
}
|
||||
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
char const* Name() const override { return "grow_gpu_hist"; }
|
||||
|
||||
private:
|
||||
bool initialised_ { false };
|
||||
bool initialised_{false};
|
||||
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
Context const* ctx_;
|
||||
|
||||
dh::AllReducer reducer_;
|
||||
|
||||
DMatrix* p_last_fmat_ { nullptr };
|
||||
DMatrix* p_last_fmat_{nullptr};
|
||||
RegTree const* p_last_tree_{nullptr};
|
||||
ObjInfo task_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
|
||||
class GPUHistMaker : public TreeUpdater {
|
||||
public:
|
||||
explicit GPUHistMaker(ObjInfo task) : task_{task} {}
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
// The passed in args can be empty, if we simply purge the old maker without
|
||||
// preserving parameters then we can't do Update on it.
|
||||
TrainParam param;
|
||||
if (float_maker_) {
|
||||
param = float_maker_->param_;
|
||||
} else if (double_maker_) {
|
||||
param = double_maker_->param_;
|
||||
}
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
||||
float_maker_->param_ = param;
|
||||
float_maker_->Configure(args, ctx_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
||||
double_maker_->param_ = param;
|
||||
double_maker_->Configure(args, ctx_);
|
||||
}
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
|
||||
FromJson(config.at("train_param"), &float_maker_->param_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
|
||||
FromJson(config.at("train_param"), &double_maker_->param_);
|
||||
}
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
out["train_param"] = ToJson(float_maker_->param_);
|
||||
} else {
|
||||
out["train_param"] = ToJson(double_maker_->param_);
|
||||
}
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_->Update(gpair, dmat, out_position, trees);
|
||||
} else {
|
||||
double_maker_->Update(gpair, dmat, out_position, trees);
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::VectorView<bst_float> p_out_preds) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
return float_maker_->UpdatePredictionCache(data, p_out_preds);
|
||||
} else {
|
||||
return double_maker_->UpdatePredictionCache(data, p_out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
char const* Name() const override {
|
||||
return "grow_gpu_hist";
|
||||
}
|
||||
|
||||
bool HasNodePosition() const override { return true; }
|
||||
|
||||
private:
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
ObjInfo task_;
|
||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPair>> float_maker_;
|
||||
std::unique_ptr<GPUHistMakerSpecialised<GradientPairPrecise>> double_maker_;
|
||||
};
|
||||
|
||||
#if !defined(GTEST_TEST)
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
|
||||
.describe("Grow tree with GPU.")
|
||||
.set_body([](ObjInfo task) { return new GPUHistMaker(task); });
|
||||
.set_body([](GenericParameter const* tparam, ObjInfo task) {
|
||||
return new GPUHistMaker(tparam, task);
|
||||
});
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
} // namespace tree
|
||||
|
||||
@ -24,6 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker);
|
||||
|
||||
class HistMaker: public BaseMaker {
|
||||
public:
|
||||
explicit HistMaker(GenericParameter const *ctx) : BaseMaker(ctx) {}
|
||||
void Update(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
@ -262,12 +263,10 @@ class HistMaker: public BaseMaker {
|
||||
}
|
||||
};
|
||||
|
||||
class CQHistMaker: public HistMaker {
|
||||
class CQHistMaker : public HistMaker {
|
||||
public:
|
||||
CQHistMaker() = default;
|
||||
char const* Name() const override {
|
||||
return "grow_local_histmaker";
|
||||
}
|
||||
explicit CQHistMaker(GenericParameter const *ctx) : HistMaker(ctx) {}
|
||||
char const *Name() const override { return "grow_local_histmaker"; }
|
||||
|
||||
protected:
|
||||
struct HistEntry {
|
||||
@ -624,9 +623,7 @@ class CQHistMaker: public HistMaker {
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||
.describe("Tree constructor that uses approximate histogram construction.")
|
||||
.set_body([](ObjInfo) {
|
||||
return new CQHistMaker();
|
||||
});
|
||||
.describe("Tree constructor that uses approximate histogram construction.")
|
||||
.set_body([](GenericParameter const *ctx, ObjInfo) { return new CQHistMaker(ctx); });
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -21,9 +21,9 @@ namespace tree {
|
||||
DMLC_REGISTRY_FILE_TAG(updater_prune);
|
||||
|
||||
/*! \brief pruner that prunes a tree after growing finishes */
|
||||
class TreePruner: public TreeUpdater {
|
||||
class TreePruner : public TreeUpdater {
|
||||
public:
|
||||
explicit TreePruner(ObjInfo task) {
|
||||
explicit TreePruner(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx) {
|
||||
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
|
||||
pruner_monitor_.Init("TreePruner");
|
||||
}
|
||||
@ -112,9 +112,7 @@ class TreePruner: public TreeUpdater {
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||
.describe("Pruner that prune the tree according to statistics.")
|
||||
.set_body([](ObjInfo task) {
|
||||
return new TreePruner(task);
|
||||
});
|
||||
.describe("Pruner that prune the tree according to statistics.")
|
||||
.set_body([](GenericParameter const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -411,6 +411,8 @@ template struct QuantileHistMaker::Builder<double>;
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
||||
.describe("Grow tree using quantized histogram.")
|
||||
.set_body([](ObjInfo task) { return new QuantileHistMaker(task); });
|
||||
.set_body([](GenericParameter const *ctx, ObjInfo task) {
|
||||
return new QuantileHistMaker(ctx, task);
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -235,7 +235,8 @@ inline BatchParam HistBatch(TrainParam const& param) {
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker: public TreeUpdater {
|
||||
public:
|
||||
explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
|
||||
explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task)
|
||||
: task_{task}, TreeUpdater(ctx) {}
|
||||
void Configure(const Args& args) override;
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
|
||||
@ -22,11 +22,10 @@ namespace tree {
|
||||
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
||||
|
||||
/*! \brief pruner that prunes a tree after growing finishs */
|
||||
class TreeRefresher: public TreeUpdater {
|
||||
class TreeRefresher : public TreeUpdater {
|
||||
public:
|
||||
void Configure(const Args& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
explicit TreeRefresher(GenericParameter const *ctx) : TreeUpdater(ctx) {}
|
||||
void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); }
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("train_param"), &this->param_);
|
||||
@ -160,9 +159,7 @@ class TreeRefresher: public TreeUpdater {
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||
.set_body([](ObjInfo) {
|
||||
return new TreeRefresher();
|
||||
});
|
||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||
.set_body([](GenericParameter const *ctx, ObjInfo) { return new TreeRefresher(ctx); });
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -20,8 +20,9 @@ DMLC_REGISTRY_FILE_TAG(updater_sync);
|
||||
* \brief syncher that synchronize the tree in all distributed nodes
|
||||
* can implement various strategies, so far it is always set to node 0's tree
|
||||
*/
|
||||
class TreeSyncher: public TreeUpdater {
|
||||
class TreeSyncher : public TreeUpdater {
|
||||
public:
|
||||
explicit TreeSyncher(GenericParameter const* tparam) : TreeUpdater(tparam) {}
|
||||
void Configure(const Args&) override {}
|
||||
|
||||
void LoadConfig(Json const&) override {}
|
||||
@ -52,9 +53,7 @@ class TreeSyncher: public TreeUpdater {
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync")
|
||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||
.set_body([](ObjInfo) {
|
||||
return new TreeSyncher();
|
||||
});
|
||||
.describe("Syncher that synchronize the tree in all distributed nodes.")
|
||||
.set_body([](GenericParameter const* tparam, ObjInfo) { return new TreeSyncher(tparam); });
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -277,8 +277,10 @@ void TestHistogramIndexImpl() {
|
||||
int constexpr kNRows = 1000, kNCols = 10;
|
||||
|
||||
// Build 2 matrices and build a histogram maker with that
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}},
|
||||
hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
|
||||
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}},
|
||||
hist_maker_ext{&generic_param,ObjInfo{ObjInfo::kRegression}};
|
||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
||||
|
||||
@ -291,10 +293,9 @@ void TestHistogramIndexImpl() {
|
||||
{"max_leaves", "0"}
|
||||
};
|
||||
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
hist_maker.Configure(training_params, &generic_param);
|
||||
hist_maker.Configure(training_params);
|
||||
hist_maker.InitDataOnce(hist_maker_dmat.get());
|
||||
hist_maker_ext.Configure(training_params, &generic_param);
|
||||
hist_maker_ext.Configure(training_params);
|
||||
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
|
||||
|
||||
// Extract the device maker from the histogram makers and from that its compressed
|
||||
@ -346,9 +347,9 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
{"sampling_method", sampling_method},
|
||||
};
|
||||
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
hist_maker.Configure(args, &generic_param);
|
||||
tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}};
|
||||
hist_maker.Configure(args);
|
||||
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
hist_maker.Update(gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, {tree});
|
||||
|
||||
@ -16,11 +16,11 @@ class TestGPUBasicModels:
|
||||
cpu_test_bm = test_bm.TestModels()
|
||||
|
||||
def run_cls(self, X, y):
|
||||
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
|
||||
cls = xgb.XGBClassifier(tree_method='gpu_hist')
|
||||
cls.fit(X, y)
|
||||
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
|
||||
|
||||
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
|
||||
cls = xgb.XGBClassifier(tree_method='gpu_hist')
|
||||
cls.fit(X, y)
|
||||
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@ parameter_strategy = strategies.fixed_dictionaries({
|
||||
'max_leaves': strategies.integers(0, 256),
|
||||
'max_bin': strategies.integers(2, 1024),
|
||||
'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
|
||||
'single_precision_histogram': strategies.booleans(),
|
||||
'min_child_weight': strategies.floats(0.5, 2.0),
|
||||
'seed': strategies.integers(0, 10),
|
||||
# We cannot enable subsampling as the training loss can increase
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user