parent
d45cf240a9
commit
a33fa05bda
@ -19,6 +19,11 @@ class GPUBuilder {
|
|||||||
void Init(const TrainParam ¶m);
|
void Init(const TrainParam ¶m);
|
||||||
~GPUBuilder();
|
~GPUBuilder();
|
||||||
|
|
||||||
|
void UpdateParam(const TrainParam ¶m)
|
||||||
|
{
|
||||||
|
this->param = param;
|
||||||
|
}
|
||||||
|
|
||||||
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
|
void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
|
||||||
RegTree *p_tree);
|
RegTree *p_tree);
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,7 @@ template <typename TStats> class GPUMaker : public TreeUpdater {
|
|||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param.learning_rate;
|
float lr = param.learning_rate;
|
||||||
param.learning_rate = lr / trees.size();
|
param.learning_rate = lr / trees.size();
|
||||||
|
builder.UpdateParam(param);
|
||||||
// build tree
|
// build tree
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
builder.Update(gpair, dmat, trees[i]);
|
builder.Update(gpair, dmat, trees[i]);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user