Remove omp_get_max_threads in gbm and linear. (#7537)
* Use ctx in gbm.
* Use ctx threads in gbm and linear.
This commit is contained in:
parent
eea094e1bc
commit
28af6f9abb
@ -38,7 +38,7 @@ class PredictionContainer;
|
|||||||
*/
|
*/
|
||||||
class GradientBooster : public Model, public Configurable {
|
class GradientBooster : public Model, public Configurable {
|
||||||
protected:
|
protected:
|
||||||
GenericParameter const* generic_param_;
|
GenericParameter const* ctx_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class GBLinearModel;
|
|||||||
*/
|
*/
|
||||||
class LinearUpdater : public Configurable {
|
class LinearUpdater : public Configurable {
|
||||||
protected:
|
protected:
|
||||||
GenericParameter const* learner_param_;
|
GenericParameter const* ctx_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
|
|||||||
@ -86,7 +86,7 @@ class GBLinear : public GradientBooster {
|
|||||||
}
|
}
|
||||||
param_.UpdateAllowUnknown(cfg);
|
param_.UpdateAllowUnknown(cfg);
|
||||||
param_.CheckGPUSupport();
|
param_.CheckGPUSupport();
|
||||||
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
updater_.reset(LinearUpdater::Create(param_.updater, ctx_));
|
||||||
updater_->Configure(cfg);
|
updater_->Configure(cfg);
|
||||||
monitor_.Init("GBLinear");
|
monitor_.Init("GBLinear");
|
||||||
}
|
}
|
||||||
@ -120,7 +120,7 @@ class GBLinear : public GradientBooster {
|
|||||||
CHECK_EQ(get<String>(in["name"]), "gblinear");
|
CHECK_EQ(get<String>(in["name"]), "gblinear");
|
||||||
FromJson(in["gblinear_train_param"], &param_);
|
FromJson(in["gblinear_train_param"], &param_);
|
||||||
param_.CheckGPUSupport();
|
param_.CheckGPUSupport();
|
||||||
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
updater_.reset(LinearUpdater::Create(param_.updater, ctx_));
|
||||||
this->updater_->LoadConfig(in["updater"]);
|
this->updater_->LoadConfig(in["updater"]);
|
||||||
}
|
}
|
||||||
void SaveConfig(Json* p_out) const override {
|
void SaveConfig(Json* p_out) const override {
|
||||||
|
|||||||
@ -26,7 +26,7 @@ GradientBooster* GradientBooster::Create(
|
|||||||
LOG(FATAL) << "Unknown gbm type " << name;
|
LOG(FATAL) << "Unknown gbm type " << name;
|
||||||
}
|
}
|
||||||
auto p_bst = (e->body)(learner_model_param);
|
auto p_bst = (e->body)(learner_model_param);
|
||||||
p_bst->generic_param_ = generic_param;
|
p_bst->ctx_ = generic_param;
|
||||||
return p_bst;
|
return p_bst;
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -49,14 +49,14 @@ void GBTree::Configure(const Args& cfg) {
|
|||||||
// configure predictors
|
// configure predictors
|
||||||
if (!cpu_predictor_) {
|
if (!cpu_predictor_) {
|
||||||
cpu_predictor_ = std::unique_ptr<Predictor>(
|
cpu_predictor_ = std::unique_ptr<Predictor>(
|
||||||
Predictor::Create("cpu_predictor", this->generic_param_));
|
Predictor::Create("cpu_predictor", this->ctx_));
|
||||||
}
|
}
|
||||||
cpu_predictor_->Configure(cfg);
|
cpu_predictor_->Configure(cfg);
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
auto n_gpus = common::AllVisibleGPUs();
|
auto n_gpus = common::AllVisibleGPUs();
|
||||||
if (!gpu_predictor_ && n_gpus != 0) {
|
if (!gpu_predictor_ && n_gpus != 0) {
|
||||||
gpu_predictor_ = std::unique_ptr<Predictor>(
|
gpu_predictor_ = std::unique_ptr<Predictor>(
|
||||||
Predictor::Create("gpu_predictor", this->generic_param_));
|
Predictor::Create("gpu_predictor", this->ctx_));
|
||||||
}
|
}
|
||||||
if (n_gpus != 0) {
|
if (n_gpus != 0) {
|
||||||
gpu_predictor_->Configure(cfg);
|
gpu_predictor_->Configure(cfg);
|
||||||
@ -201,16 +201,16 @@ void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void CopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
|
void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_threads,
|
||||||
bst_group_t n_groups, bst_group_t group_id,
|
bst_group_t n_groups, bst_group_t group_id,
|
||||||
HostDeviceVector<GradientPair> *out_gpair) {
|
HostDeviceVector<GradientPair>* out_gpair) {
|
||||||
if (in_gpair->DeviceIdx() != GenericParameter::kCpuId) {
|
if (in_gpair->DeviceIdx() != GenericParameter::kCpuId) {
|
||||||
GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
|
GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
|
||||||
} else {
|
} else {
|
||||||
std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
|
std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
|
||||||
auto nsize = static_cast<bst_omp_uint>(out_gpair->Size());
|
auto nsize = static_cast<bst_omp_uint>(out_gpair->Size());
|
||||||
const auto &gpair_h = in_gpair->ConstHostVector();
|
const auto &gpair_h = in_gpair->ConstHostVector();
|
||||||
common::ParallelFor(nsize, [&](bst_omp_uint i) {
|
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
|
||||||
tmp_h[i] = gpair_h[i * n_groups + group_id];
|
tmp_h[i] = gpair_h[i * n_groups + group_id];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -228,7 +228,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
|
|||||||
// break a lot of existing code.
|
// break a lot of existing code.
|
||||||
auto device = tparam_.tree_method != TreeMethod::kGPUHist
|
auto device = tparam_.tree_method != TreeMethod::kGPUHist
|
||||||
? GenericParameter::kCpuId
|
? GenericParameter::kCpuId
|
||||||
: generic_param_->gpu_id;
|
: ctx_->gpu_id;
|
||||||
auto out = linalg::TensorView<float, 2>{
|
auto out = linalg::TensorView<float, 2>{
|
||||||
device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
|
device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
|
||||||
: predt->predictions.DeviceSpan(),
|
: predt->predictions.DeviceSpan(),
|
||||||
@ -255,7 +255,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
|
|||||||
in_gpair->DeviceIdx());
|
in_gpair->DeviceIdx());
|
||||||
bool update_predict = true;
|
bool update_predict = true;
|
||||||
for (int gid = 0; gid < ngroup; ++gid) {
|
for (int gid = 0; gid < ngroup; ++gid) {
|
||||||
CopyGradient(in_gpair, ngroup, gid, &tmp);
|
CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp);
|
||||||
std::vector<std::unique_ptr<RegTree> > ret;
|
std::vector<std::unique_ptr<RegTree> > ret;
|
||||||
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
||||||
const size_t num_new_trees = ret.size();
|
const size_t num_new_trees = ret.size();
|
||||||
@ -310,7 +310,7 @@ void GBTree::InitUpdater(Args const& cfg) {
|
|||||||
// create new updaters
|
// create new updaters
|
||||||
for (const std::string& pstr : ups) {
|
for (const std::string& pstr : ups) {
|
||||||
std::unique_ptr<TreeUpdater> up(
|
std::unique_ptr<TreeUpdater> up(
|
||||||
TreeUpdater::Create(pstr.c_str(), generic_param_, model_.learner_model_param->task));
|
TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task));
|
||||||
up->Configure(cfg);
|
up->Configure(cfg);
|
||||||
updaters_.push_back(std::move(up));
|
updaters_.push_back(std::move(up));
|
||||||
}
|
}
|
||||||
@ -396,7 +396,7 @@ void GBTree::LoadConfig(Json const& in) {
|
|||||||
updaters_.clear();
|
updaters_.clear();
|
||||||
for (auto const& kv : j_updaters) {
|
for (auto const& kv : j_updaters) {
|
||||||
std::unique_ptr<TreeUpdater> up(
|
std::unique_ptr<TreeUpdater> up(
|
||||||
TreeUpdater::Create(kv.first, generic_param_, model_.learner_model_param->task));
|
TreeUpdater::Create(kv.first, ctx_, model_.learner_model_param->task));
|
||||||
up->LoadConfig(kv.second);
|
up->LoadConfig(kv.second);
|
||||||
updaters_.push_back(std::move(up));
|
updaters_.push_back(std::move(up));
|
||||||
}
|
}
|
||||||
@ -562,7 +562,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
|||||||
auto on_device = is_ellpack || is_from_device;
|
auto on_device = is_ellpack || is_from_device;
|
||||||
|
|
||||||
// Use GPU Predictor if data is already on device and gpu_id is set.
|
// Use GPU Predictor if data is already on device and gpu_id is set.
|
||||||
if (on_device && generic_param_->gpu_id >= 0) {
|
if (on_device && ctx_->gpu_id >= 0) {
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
|
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
|
||||||
CHECK(gpu_predictor_);
|
CHECK(gpu_predictor_);
|
||||||
@ -728,8 +728,8 @@ class Dart : public GBTree {
|
|||||||
auto n_groups = model_.learner_model_param->num_output_group;
|
auto n_groups = model_.learner_model_param->num_output_group;
|
||||||
|
|
||||||
PredictionCacheEntry predts; // temporary storage for prediction
|
PredictionCacheEntry predts; // temporary storage for prediction
|
||||||
if (generic_param_->gpu_id != GenericParameter::kCpuId) {
|
if (ctx_->gpu_id != GenericParameter::kCpuId) {
|
||||||
predts.predictions.SetDevice(generic_param_->gpu_id);
|
predts.predictions.SetDevice(ctx_->gpu_id);
|
||||||
}
|
}
|
||||||
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||||
|
|
||||||
@ -758,11 +758,10 @@ class Dart : public GBTree {
|
|||||||
} else {
|
} else {
|
||||||
auto &h_out_predts = p_out_preds->predictions.HostVector();
|
auto &h_out_predts = p_out_preds->predictions.HostVector();
|
||||||
auto &h_predts = predts.predictions.HostVector();
|
auto &h_predts = predts.predictions.HostVector();
|
||||||
#pragma omp parallel for
|
common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) {
|
||||||
for (omp_ulong ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
|
|
||||||
const size_t offset = ridx * n_groups + group;
|
const size_t offset = ridx * n_groups + group;
|
||||||
h_out_predts[offset] += (h_predts[offset] * w);
|
h_out_predts[offset] += (h_predts[offset] * w);
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -846,13 +845,11 @@ class Dart : public GBTree {
|
|||||||
if (device == GenericParameter::kCpuId) {
|
if (device == GenericParameter::kCpuId) {
|
||||||
auto &h_predts = predts.predictions.HostVector();
|
auto &h_predts = predts.predictions.HostVector();
|
||||||
auto &h_out_predts = out_preds->predictions.HostVector();
|
auto &h_out_predts = out_preds->predictions.HostVector();
|
||||||
#pragma omp parallel for
|
common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) {
|
||||||
for (omp_ulong ridx = 0; ridx < n_rows; ++ridx) {
|
|
||||||
const size_t offset = ridx * n_groups + group;
|
const size_t offset = ridx * n_groups + group;
|
||||||
// Need to remove the base margin from individual tree.
|
// Need to remove the base margin from individual tree.
|
||||||
h_out_predts[offset] +=
|
h_out_predts[offset] += (h_predts[offset] - model_.learner_model_param->base_score) * w;
|
||||||
(h_predts[offset] - model_.learner_model_param->base_score) * w;
|
});
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
out_preds->predictions.SetDevice(device);
|
out_preds->predictions.SetDevice(device);
|
||||||
predts.predictions.SetDevice(device);
|
predts.predictions.SetDevice(device);
|
||||||
|
|||||||
@ -413,10 +413,9 @@ class GBTree : public GradientBooster {
|
|||||||
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
|
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||||
bool with_stats,
|
|
||||||
std::string format) const override {
|
std::string format) const override {
|
||||||
return model_.DumpModel(fmap, with_stats, format);
|
return model_.DumpModel(fmap, with_stats, this->ctx_->Threads(), format);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|||||||
@ -109,12 +109,11 @@ struct GBTreeModel : public Model {
|
|||||||
void SaveModel(Json* p_out) const override;
|
void SaveModel(Json* p_out) const override;
|
||||||
void LoadModel(Json const& p_out) override;
|
void LoadModel(Json const& p_out) override;
|
||||||
|
|
||||||
std::vector<std::string> DumpModel(const FeatureMap &fmap, bool with_stats,
|
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats, int32_t n_threads,
|
||||||
std::string format) const {
|
std::string format) const {
|
||||||
std::vector<std::string> dump(trees.size());
|
std::vector<std::string> dump(trees.size());
|
||||||
common::ParallelFor(static_cast<omp_ulong>(trees.size()), [&](size_t i) {
|
common::ParallelFor(trees.size(), n_threads,
|
||||||
dump[i] = trees[i]->DumpModel(fmap, with_stats, format);
|
[&](size_t i) { dump[i] = trees[i]->DumpModel(fmap, with_stats, format); });
|
||||||
});
|
|
||||||
return dump;
|
return dump;
|
||||||
}
|
}
|
||||||
void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
|
void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
|
||||||
|
|||||||
@ -149,21 +149,21 @@ GetGradientParallel(GenericParameter const *ctx, int group_idx, int num_group,
|
|||||||
*/
|
*/
|
||||||
inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_group,
|
inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_group,
|
||||||
const std::vector<GradientPair> &gpair,
|
const std::vector<GradientPair> &gpair,
|
||||||
DMatrix *p_fmat) {
|
DMatrix *p_fmat, int32_t n_threads) {
|
||||||
double sum_grad = 0.0, sum_hess = 0.0;
|
|
||||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||||
dmlc::OMPException exc;
|
std::vector<double> sum_grad_tloc(n_threads, 0);
|
||||||
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
|
std::vector<double> sum_hess_tloc(n_threads, 0);
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
|
||||||
exc.Run([&]() {
|
common::ParallelFor(ndata, n_threads, [&](auto i) {
|
||||||
auto &p = gpair[i * num_group + group_idx];
|
auto tid = omp_get_thread_num();
|
||||||
if (p.GetHess() >= 0.0f) {
|
auto &p = gpair[i * num_group + group_idx];
|
||||||
sum_grad += p.GetGrad();
|
if (p.GetHess() >= 0.0f) {
|
||||||
sum_hess += p.GetHess();
|
sum_grad_tloc[tid] += p.GetGrad();
|
||||||
}
|
sum_hess_tloc[tid] += p.GetHess();
|
||||||
});
|
}
|
||||||
}
|
});
|
||||||
exc.Rethrow();
|
double sum_grad = std::accumulate(sum_grad_tloc.cbegin(), sum_grad_tloc.cend(), 0.0);
|
||||||
|
double sum_hess = std::accumulate(sum_hess_tloc.cbegin(), sum_hess_tloc.cend(), 0.0);
|
||||||
return std::make_pair(sum_grad, sum_hess);
|
return std::make_pair(sum_grad, sum_hess);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,23 +179,18 @@ inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_
|
|||||||
*/
|
*/
|
||||||
inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
||||||
float dw, std::vector<GradientPair> *in_gpair,
|
float dw, std::vector<GradientPair> *in_gpair,
|
||||||
DMatrix *p_fmat) {
|
DMatrix *p_fmat, int32_t n_threads) {
|
||||||
if (dw == 0.0f) return;
|
if (dw == 0.0f) return;
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
auto col = page[fidx];
|
auto col = page[fidx];
|
||||||
// update grad value
|
// update grad value
|
||||||
const auto num_row = static_cast<bst_omp_uint>(col.size());
|
const auto num_row = static_cast<bst_omp_uint>(col.size());
|
||||||
dmlc::OMPException exc;
|
common::ParallelFor(num_row, n_threads, [&](auto j) {
|
||||||
#pragma omp parallel for schedule(static)
|
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
|
||||||
for (bst_omp_uint j = 0; j < num_row; ++j) {
|
if (p.GetHess() < 0.0f) return;
|
||||||
exc.Run([&]() {
|
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
|
||||||
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
|
});
|
||||||
if (p.GetHess() < 0.0f) return;
|
|
||||||
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
exc.Rethrow();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,20 +204,15 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
|||||||
* \param p_fmat The input feature matrix.
|
* \param p_fmat The input feature matrix.
|
||||||
*/
|
*/
|
||||||
inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
|
inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
|
||||||
std::vector<GradientPair> *in_gpair,
|
std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||||
DMatrix *p_fmat) {
|
int32_t n_threads) {
|
||||||
if (dbias == 0.0f) return;
|
if (dbias == 0.0f) return;
|
||||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||||
dmlc::OMPException exc;
|
common::ParallelFor(ndata, n_threads, [&](auto i) {
|
||||||
#pragma omp parallel for schedule(static)
|
GradientPair &g = (*in_gpair)[i * num_group + group_idx];
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
if (g.GetHess() < 0.0f) return;
|
||||||
exc.Run([&]() {
|
g += GradientPair(g.GetHess() * dbias, 0);
|
||||||
GradientPair &g = (*in_gpair)[i * num_group + group_idx];
|
});
|
||||||
if (g.GetHess() < 0.0f) return;
|
|
||||||
g += GradientPair(g.GetHess() * dbias, 0);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
exc.Rethrow();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -230,9 +220,13 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias
|
|||||||
* in coordinate descent algorithms.
|
* in coordinate descent algorithms.
|
||||||
*/
|
*/
|
||||||
class FeatureSelector {
|
class FeatureSelector {
|
||||||
|
protected:
|
||||||
|
int32_t n_threads_{-1};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
explicit FeatureSelector(int32_t n_threads) : n_threads_{n_threads} {}
|
||||||
/*! \brief factory method */
|
/*! \brief factory method */
|
||||||
static FeatureSelector *Create(int choice);
|
static FeatureSelector *Create(int choice, int32_t n_threads);
|
||||||
/*! \brief virtual destructor */
|
/*! \brief virtual destructor */
|
||||||
virtual ~FeatureSelector() = default;
|
virtual ~FeatureSelector() = default;
|
||||||
/**
|
/**
|
||||||
@ -274,6 +268,7 @@ class FeatureSelector {
|
|||||||
*/
|
*/
|
||||||
class CyclicFeatureSelector : public FeatureSelector {
|
class CyclicFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
|
using FeatureSelector::FeatureSelector;
|
||||||
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
int NextFeature(int iteration, const gbm::GBLinearModel &model,
|
||||||
int , const std::vector<GradientPair> &,
|
int , const std::vector<GradientPair> &,
|
||||||
DMatrix *, float, float) override {
|
DMatrix *, float, float) override {
|
||||||
@ -287,6 +282,7 @@ class CyclicFeatureSelector : public FeatureSelector {
|
|||||||
*/
|
*/
|
||||||
class ShuffleFeatureSelector : public FeatureSelector {
|
class ShuffleFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
|
using FeatureSelector::FeatureSelector;
|
||||||
void Setup(const gbm::GBLinearModel &model,
|
void Setup(const gbm::GBLinearModel &model,
|
||||||
const std::vector<GradientPair>&,
|
const std::vector<GradientPair>&,
|
||||||
DMatrix *, float, float, int) override {
|
DMatrix *, float, float, int) override {
|
||||||
@ -313,6 +309,7 @@ class ShuffleFeatureSelector : public FeatureSelector {
|
|||||||
*/
|
*/
|
||||||
class RandomFeatureSelector : public FeatureSelector {
|
class RandomFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
|
using FeatureSelector::FeatureSelector;
|
||||||
int NextFeature(int, const gbm::GBLinearModel &model,
|
int NextFeature(int, const gbm::GBLinearModel &model,
|
||||||
int, const std::vector<GradientPair> &,
|
int, const std::vector<GradientPair> &,
|
||||||
DMatrix *, float, float) override {
|
DMatrix *, float, float) override {
|
||||||
@ -331,6 +328,7 @@ class RandomFeatureSelector : public FeatureSelector {
|
|||||||
*/
|
*/
|
||||||
class GreedyFeatureSelector : public FeatureSelector {
|
class GreedyFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
|
using FeatureSelector::FeatureSelector;
|
||||||
void Setup(const gbm::GBLinearModel &model,
|
void Setup(const gbm::GBLinearModel &model,
|
||||||
const std::vector<GradientPair> &,
|
const std::vector<GradientPair> &,
|
||||||
DMatrix *, float, float, int param) override {
|
DMatrix *, float, float, int param) override {
|
||||||
@ -360,7 +358,7 @@ class GreedyFeatureSelector : public FeatureSelector {
|
|||||||
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
common::ParallelFor(nfeat, [&](bst_omp_uint i) {
|
common::ParallelFor(nfeat, this->n_threads_, [&](bst_omp_uint i) {
|
||||||
const auto col = page[i];
|
const auto col = page[i];
|
||||||
const bst_uint ndata = col.size();
|
const bst_uint ndata = col.size();
|
||||||
auto &sums = gpair_sums_[group_idx * nfeat + i];
|
auto &sums = gpair_sums_[group_idx * nfeat + i];
|
||||||
@ -407,6 +405,7 @@ class GreedyFeatureSelector : public FeatureSelector {
|
|||||||
*/
|
*/
|
||||||
class ThriftyFeatureSelector : public FeatureSelector {
|
class ThriftyFeatureSelector : public FeatureSelector {
|
||||||
public:
|
public:
|
||||||
|
using FeatureSelector::FeatureSelector;
|
||||||
void Setup(const gbm::GBLinearModel &model,
|
void Setup(const gbm::GBLinearModel &model,
|
||||||
const std::vector<GradientPair> &gpair,
|
const std::vector<GradientPair> &gpair,
|
||||||
DMatrix *p_fmat, float alpha, float lambda, int param) override {
|
DMatrix *p_fmat, float alpha, float lambda, int param) override {
|
||||||
@ -426,7 +425,7 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
|||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
// column-parallel is usually faster than row-parallel
|
// column-parallel is usually faster than row-parallel
|
||||||
common::ParallelFor(nfeat, [&](bst_omp_uint i) {
|
common::ParallelFor(nfeat, this->n_threads_, [&](auto i) {
|
||||||
const auto col = page[i];
|
const auto col = page[i];
|
||||||
const bst_uint ndata = col.size();
|
const bst_uint ndata = col.size();
|
||||||
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
||||||
@ -483,18 +482,18 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
|||||||
std::vector<std::pair<double, double>> gpair_sums_;
|
std::vector<std::pair<double, double>> gpair_sums_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline FeatureSelector *FeatureSelector::Create(int choice) {
|
inline FeatureSelector *FeatureSelector::Create(int choice, int32_t n_threads) {
|
||||||
switch (choice) {
|
switch (choice) {
|
||||||
case kCyclic:
|
case kCyclic:
|
||||||
return new CyclicFeatureSelector();
|
return new CyclicFeatureSelector(n_threads);
|
||||||
case kShuffle:
|
case kShuffle:
|
||||||
return new ShuffleFeatureSelector();
|
return new ShuffleFeatureSelector(n_threads);
|
||||||
case kThrifty:
|
case kThrifty:
|
||||||
return new ThriftyFeatureSelector();
|
return new ThriftyFeatureSelector(n_threads);
|
||||||
case kGreedy:
|
case kGreedy:
|
||||||
return new GreedyFeatureSelector();
|
return new GreedyFeatureSelector(n_threads);
|
||||||
case kRandom:
|
case kRandom:
|
||||||
return new RandomFeatureSelector();
|
return new RandomFeatureSelector(n_threads);
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unknown coordinate selector: " << choice;
|
LOG(FATAL) << "unknown coordinate selector: " << choice;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,7 +17,7 @@ LinearUpdater* LinearUpdater::Create(const std::string& name, GenericParameter c
|
|||||||
LOG(FATAL) << "Unknown linear updater " << name;
|
LOG(FATAL) << "Unknown linear updater " << name;
|
||||||
}
|
}
|
||||||
auto p_linear = (e->body)();
|
auto p_linear = (e->body)();
|
||||||
p_linear->learner_param_ = lparam;
|
p_linear->ctx_ = lparam;
|
||||||
return p_linear;
|
return p_linear;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
tparam_.UpdateAllowUnknown(args)
|
tparam_.UpdateAllowUnknown(args)
|
||||||
};
|
};
|
||||||
cparam_.UpdateAllowUnknown(rest);
|
cparam_.UpdateAllowUnknown(rest);
|
||||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
selector_.reset(FeatureSelector::Create(tparam_.feature_selector, ctx_->Threads()));
|
||||||
monitor_.Init("CoordinateUpdater");
|
monitor_.Init("CoordinateUpdater");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,13 +51,13 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
const int ngroup = model->learner_model_param->num_output_group;
|
const int ngroup = model->learner_model_param->num_output_group;
|
||||||
// update bias
|
// update bias
|
||||||
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
||||||
auto grad = GetBiasGradientParallel(group_idx, ngroup,
|
auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->ConstHostVector(), p_fmat,
|
||||||
in_gpair->ConstHostVector(), p_fmat);
|
ctx_->Threads());
|
||||||
auto dbias = static_cast<float>(tparam_.learning_rate *
|
auto dbias = static_cast<float>(tparam_.learning_rate *
|
||||||
CoordinateDeltaBias(grad.first, grad.second));
|
CoordinateDeltaBias(grad.first, grad.second));
|
||||||
model->Bias()[group_idx] += dbias;
|
model->Bias()[group_idx] += dbias;
|
||||||
UpdateBiasResidualParallel(group_idx, ngroup,
|
UpdateBiasResidualParallel(group_idx, ngroup, dbias, &in_gpair->HostVector(), p_fmat,
|
||||||
dbias, &in_gpair->HostVector(), p_fmat);
|
ctx_->Threads());
|
||||||
}
|
}
|
||||||
// prepare for updating the weights
|
// prepare for updating the weights
|
||||||
selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
|
selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
|
||||||
@ -80,14 +80,15 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
||||||
const int ngroup = model->learner_model_param->num_output_group;
|
const int ngroup = model->learner_model_param->num_output_group;
|
||||||
bst_float &w = (*model)[fidx][group_idx];
|
bst_float &w = (*model)[fidx][group_idx];
|
||||||
auto gradient = GetGradientParallel(learner_param_, group_idx, ngroup, fidx,
|
auto gradient = GetGradientParallel(ctx_, group_idx, ngroup, fidx,
|
||||||
*in_gpair, p_fmat);
|
*in_gpair, p_fmat);
|
||||||
auto dw = static_cast<float>(
|
auto dw = static_cast<float>(
|
||||||
tparam_.learning_rate *
|
tparam_.learning_rate *
|
||||||
CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
|
CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
|
||||||
tparam_.reg_lambda_denorm));
|
tparam_.reg_lambda_denorm));
|
||||||
w += dw;
|
w += dw;
|
||||||
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
|
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat,
|
||||||
|
ctx_->Threads());
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
void Configure(Args const& args) override {
|
void Configure(Args const& args) override {
|
||||||
tparam_.UpdateAllowUnknown(args);
|
tparam_.UpdateAllowUnknown(args);
|
||||||
coord_param_.UpdateAllowUnknown(args);
|
coord_param_.UpdateAllowUnknown(args);
|
||||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
selector_.reset(FeatureSelector::Create(tparam_.feature_selector, ctx_->Threads()));
|
||||||
monitor_.Init("GPUCoordinateUpdater");
|
monitor_.Init("GPUCoordinateUpdater");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LazyInitDevice(DMatrix *p_fmat, const LearnerModelParam &model_param) {
|
void LazyInitDevice(DMatrix *p_fmat, const LearnerModelParam &model_param) {
|
||||||
if (learner_param_->gpu_id < 0) return;
|
if (ctx_->gpu_id < 0) return;
|
||||||
|
|
||||||
num_row_ = static_cast<size_t>(p_fmat->Info().num_row_);
|
num_row_ = static_cast<size_t>(p_fmat->Info().num_row_);
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
// The begin and end indices for the section of each column associated with
|
// The begin and end indices for the section of each column associated with
|
||||||
// this device
|
// this device
|
||||||
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
||||||
@ -103,7 +103,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
monitor_.Start("UpdateGpair");
|
monitor_.Start("UpdateGpair");
|
||||||
auto &in_gpair_host = in_gpair->ConstHostVector();
|
auto &in_gpair_host = in_gpair->ConstHostVector();
|
||||||
// Update gpair
|
// Update gpair
|
||||||
if (learner_param_->gpu_id >= 0) {
|
if (ctx_->gpu_id >= 0) {
|
||||||
this->UpdateGpair(in_gpair_host);
|
this->UpdateGpair(in_gpair_host);
|
||||||
}
|
}
|
||||||
monitor_.Stop("UpdateGpair");
|
monitor_.Stop("UpdateGpair");
|
||||||
@ -134,7 +134,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
++group_idx) {
|
++group_idx) {
|
||||||
// Get gradient
|
// Get gradient
|
||||||
auto grad = GradientPair(0, 0);
|
auto grad = GradientPair(0, 0);
|
||||||
if (learner_param_->gpu_id >= 0) {
|
if (ctx_->gpu_id >= 0) {
|
||||||
grad = GetBiasGradient(group_idx, model->learner_model_param->num_output_group);
|
grad = GetBiasGradient(group_idx, model->learner_model_param->num_output_group);
|
||||||
}
|
}
|
||||||
auto dbias = static_cast<float>(
|
auto dbias = static_cast<float>(
|
||||||
@ -143,7 +143,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
model->Bias()[group_idx] += dbias;
|
model->Bias()[group_idx] += dbias;
|
||||||
|
|
||||||
// Update residual
|
// Update residual
|
||||||
if (learner_param_->gpu_id >= 0) {
|
if (ctx_->gpu_id >= 0) {
|
||||||
UpdateBiasResidual(dbias, group_idx, model->learner_model_param->num_output_group);
|
UpdateBiasResidual(dbias, group_idx, model->learner_model_param->num_output_group);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -155,7 +155,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
bst_float &w = (*model)[fidx][group_idx];
|
bst_float &w = (*model)[fidx][group_idx];
|
||||||
// Get gradient
|
// Get gradient
|
||||||
auto grad = GradientPair(0, 0);
|
auto grad = GradientPair(0, 0);
|
||||||
if (learner_param_->gpu_id >= 0) {
|
if (ctx_->gpu_id >= 0) {
|
||||||
grad = GetGradient(group_idx, model->learner_model_param->num_output_group, fidx);
|
grad = GetGradient(group_idx, model->learner_model_param->num_output_group, fidx);
|
||||||
}
|
}
|
||||||
auto dw = static_cast<float>(tparam_.learning_rate *
|
auto dw = static_cast<float>(tparam_.learning_rate *
|
||||||
@ -164,14 +164,14 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
tparam_.reg_lambda_denorm));
|
tparam_.reg_lambda_denorm));
|
||||||
w += dw;
|
w += dw;
|
||||||
|
|
||||||
if (learner_param_->gpu_id >= 0) {
|
if (ctx_->gpu_id >= 0) {
|
||||||
UpdateResidual(dw, group_idx, model->learner_model_param->num_output_group, fidx);
|
UpdateResidual(dw, group_idx, model->learner_model_param->num_output_group, fidx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This needs to be public because of the __device__ lambda.
|
// This needs to be public because of the __device__ lambda.
|
||||||
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
||||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
auto counting = thrust::make_counting_iterator(0ull);
|
auto counting = thrust::make_counting_iterator(0ull);
|
||||||
auto f = [=] __device__(size_t idx) {
|
auto f = [=] __device__(size_t idx) {
|
||||||
return idx * num_group + group_idx;
|
return idx * num_group + group_idx;
|
||||||
@ -195,7 +195,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
|||||||
|
|
||||||
// This needs to be public because of the __device__ lambda.
|
// This needs to be public because of the __device__ lambda.
|
||||||
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
||||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
common::Span<xgboost::Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
|
common::Span<xgboost::Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
|
||||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||||
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
|
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class ShotgunUpdater : public LinearUpdater {
|
|||||||
LOG(FATAL) << "Unsupported feature selector for shotgun updater.\n"
|
LOG(FATAL) << "Unsupported feature selector for shotgun updater.\n"
|
||||||
<< "Supported options are: {cyclic, shuffle}";
|
<< "Supported options are: {cyclic, shuffle}";
|
||||||
}
|
}
|
||||||
selector_.reset(FeatureSelector::Create(param_.feature_selector));
|
selector_.reset(FeatureSelector::Create(param_.feature_selector, ctx_->Threads()));
|
||||||
}
|
}
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
auto const& config = get<Object const>(in);
|
auto const& config = get<Object const>(in);
|
||||||
@ -40,12 +40,13 @@ class ShotgunUpdater : public LinearUpdater {
|
|||||||
|
|
||||||
// update bias
|
// update bias
|
||||||
for (int gid = 0; gid < ngroup; ++gid) {
|
for (int gid = 0; gid < ngroup; ++gid) {
|
||||||
auto grad = GetBiasGradientParallel(gid, ngroup,
|
auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->ConstHostVector(), p_fmat,
|
||||||
in_gpair->ConstHostVector(), p_fmat);
|
ctx_->Threads());
|
||||||
auto dbias = static_cast<bst_float>(param_.learning_rate *
|
auto dbias = static_cast<bst_float>(param_.learning_rate *
|
||||||
CoordinateDeltaBias(grad.first, grad.second));
|
CoordinateDeltaBias(grad.first, grad.second));
|
||||||
model->Bias()[gid] += dbias;
|
model->Bias()[gid] += dbias;
|
||||||
UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat);
|
UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat,
|
||||||
|
ctx_->Threads());
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock-free parallel updates of weights
|
// lock-free parallel updates of weights
|
||||||
@ -54,42 +55,35 @@ class ShotgunUpdater : public LinearUpdater {
|
|||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
|
const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
|
||||||
dmlc::OMPException exc;
|
common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) {
|
||||||
#pragma omp parallel for schedule(static)
|
int ii = selector_->NextFeature(i, *model, 0, in_gpair->ConstHostVector(), p_fmat,
|
||||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
param_.reg_alpha_denorm, param_.reg_lambda_denorm);
|
||||||
exc.Run([&]() {
|
if (ii < 0) return;
|
||||||
int ii = selector_->NextFeature
|
const bst_uint fid = ii;
|
||||||
(i, *model, 0, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm,
|
auto col = page[ii];
|
||||||
param_.reg_lambda_denorm);
|
for (int gid = 0; gid < ngroup; ++gid) {
|
||||||
if (ii < 0) return;
|
double sum_grad = 0.0, sum_hess = 0.0;
|
||||||
const bst_uint fid = ii;
|
for (auto &c : col) {
|
||||||
auto col = page[ii];
|
const GradientPair &p = gpair[c.index * ngroup + gid];
|
||||||
for (int gid = 0; gid < ngroup; ++gid) {
|
if (p.GetHess() < 0.0f) continue;
|
||||||
double sum_grad = 0.0, sum_hess = 0.0;
|
const bst_float v = c.fvalue;
|
||||||
for (auto& c : col) {
|
sum_grad += p.GetGrad() * v;
|
||||||
const GradientPair &p = gpair[c.index * ngroup + gid];
|
sum_hess += p.GetHess() * v * v;
|
||||||
if (p.GetHess() < 0.0f) continue;
|
|
||||||
const bst_float v = c.fvalue;
|
|
||||||
sum_grad += p.GetGrad() * v;
|
|
||||||
sum_hess += p.GetHess() * v * v;
|
|
||||||
}
|
|
||||||
bst_float &w = (*model)[fid][gid];
|
|
||||||
auto dw = static_cast<bst_float>(
|
|
||||||
param_.learning_rate *
|
|
||||||
CoordinateDelta(sum_grad, sum_hess, w, param_.reg_alpha_denorm,
|
|
||||||
param_.reg_lambda_denorm));
|
|
||||||
if (dw == 0.f) continue;
|
|
||||||
w += dw;
|
|
||||||
// update grad values
|
|
||||||
for (auto& c : col) {
|
|
||||||
GradientPair &p = gpair[c.index * ngroup + gid];
|
|
||||||
if (p.GetHess() < 0.0f) continue;
|
|
||||||
p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
bst_float &w = (*model)[fid][gid];
|
||||||
}
|
auto dw = static_cast<bst_float>(
|
||||||
exc.Rethrow();
|
param_.learning_rate * CoordinateDelta(sum_grad, sum_hess, w, param_.reg_alpha_denorm,
|
||||||
|
param_.reg_lambda_denorm));
|
||||||
|
if (dw == 0.f) continue;
|
||||||
|
w += dw;
|
||||||
|
// update grad values
|
||||||
|
for (auto &c : col) {
|
||||||
|
GradientPair &p = gpair[c.index * ngroup + gid];
|
||||||
|
if (p.GetHess() < 0.0f) continue;
|
||||||
|
p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user