quick fix
This commit is contained in:
parent
1e23af2adc
commit
cd9c81be91
@ -26,10 +26,11 @@ struct LinearModel {
|
|||||||
int reserved[16];
|
int reserved[16];
|
||||||
// constructor
|
// constructor
|
||||||
ModelParam(void) {
|
ModelParam(void) {
|
||||||
|
memset(this, 0, sizeof(ModelParam));
|
||||||
base_score = 0.5f;
|
base_score = 0.5f;
|
||||||
num_feature = 0;
|
num_feature = 0;
|
||||||
loss_type = 1;
|
loss_type = 1;
|
||||||
std::memset(reserved, 0, sizeof(reserved));
|
num_feature = 0;
|
||||||
}
|
}
|
||||||
// initialize base score
|
// initialize base score
|
||||||
inline void InitBaseScore(void) {
|
inline void InitBaseScore(void) {
|
||||||
@ -119,7 +120,7 @@ struct LinearModel {
|
|||||||
}
|
}
|
||||||
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||||
}
|
}
|
||||||
inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
|
inline void Save(rabit::IStream &fo, const float *wptr = NULL) {
|
||||||
fo.Write(&param, sizeof(param));
|
fo.Write(&param, sizeof(param));
|
||||||
if (wptr == NULL) wptr = weight;
|
if (wptr == NULL) wptr = weight;
|
||||||
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
|
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
|
||||||
|
|||||||
@ -226,7 +226,7 @@ class LBFGSSolver {
|
|||||||
const size_t num_dim = gstate.num_dim;
|
const size_t num_dim = gstate.num_dim;
|
||||||
const DType *gsub = grad + range_begin_;
|
const DType *gsub = grad + range_begin_;
|
||||||
const size_t nsub = range_end_ - range_begin_;
|
const size_t nsub = range_end_ - range_begin_;
|
||||||
double vdot;
|
double vdot = 0.0;
|
||||||
if (n != 0) {
|
if (n != 0) {
|
||||||
// hist[m + n - 1] stores old gradient
|
// hist[m + n - 1] stores old gradient
|
||||||
Minus(hist[m + n - 1], gsub, hist[m + n - 1], nsub);
|
Minus(hist[m + n - 1], gsub, hist[m + n - 1], nsub);
|
||||||
@ -242,15 +242,19 @@ class LBFGSSolver {
|
|||||||
idxset.push_back(std::make_pair(m + j, 2 * m));
|
idxset.push_back(std::make_pair(m + j, 2 * m));
|
||||||
idxset.push_back(std::make_pair(m + j, m + n - 1));
|
idxset.push_back(std::make_pair(m + j, m + n - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculate dot products
|
// calculate dot products
|
||||||
std::vector<double> tmp(idxset.size());
|
std::vector<double> tmp(idxset.size());
|
||||||
for (size_t i = 0; i < tmp.size(); ++i) {
|
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||||
tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second);
|
tmp[i] = hist.CalcDot(idxset[i].first, idxset[i].second);
|
||||||
}
|
}
|
||||||
|
|
||||||
rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size());
|
rabit::Allreduce<rabit::op::Sum>(BeginPtr(tmp), tmp.size());
|
||||||
|
|
||||||
for (size_t i = 0; i < tmp.size(); ++i) {
|
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||||
gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
|
gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// BFGS steps, use vector-free update
|
// BFGS steps, use vector-free update
|
||||||
// parameterize vector using basis in hist
|
// parameterize vector using basis in hist
|
||||||
std::vector<double> alpha(n);
|
std::vector<double> alpha(n);
|
||||||
@ -264,7 +268,7 @@ class LBFGSSolver {
|
|||||||
}
|
}
|
||||||
alpha[j] = vsum / gstate.DotBuf(j, m + j);
|
alpha[j] = vsum / gstate.DotBuf(j, m + j);
|
||||||
delta[m + j] = delta[m + j] - alpha[j];
|
delta[m + j] = delta[m + j] - alpha[j];
|
||||||
}
|
}
|
||||||
// scale
|
// scale
|
||||||
double scale = gstate.DotBuf(n - 1, m + n - 1) /
|
double scale = gstate.DotBuf(n - 1, m + n - 1) /
|
||||||
gstate.DotBuf(m + n - 1, m + n - 1);
|
gstate.DotBuf(m + n - 1, m + n - 1);
|
||||||
@ -280,6 +284,7 @@ class LBFGSSolver {
|
|||||||
double beta = vsum / gstate.DotBuf(j, m + j);
|
double beta = vsum / gstate.DotBuf(j, m + j);
|
||||||
delta[j] = delta[j] + (alpha[j] - beta);
|
delta[j] = delta[j] + (alpha[j] - beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
// set all to zero
|
// set all to zero
|
||||||
std::fill(dir, dir + num_dim, 0.0f);
|
std::fill(dir, dir + num_dim, 0.0f);
|
||||||
DType *dirsub = dir + range_begin_;
|
DType *dirsub = dir + range_begin_;
|
||||||
@ -292,10 +297,11 @@ class LBFGSSolver {
|
|||||||
}
|
}
|
||||||
FixDirL1Sign(dirsub, hist[2 * m], nsub);
|
FixDirL1Sign(dirsub, hist[2 * m], nsub);
|
||||||
vdot = -Dot(dirsub, hist[2 * m], nsub);
|
vdot = -Dot(dirsub, hist[2 * m], nsub);
|
||||||
|
|
||||||
// allreduce to get full direction
|
// allreduce to get full direction
|
||||||
rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
|
rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
|
||||||
rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
|
rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
|
||||||
} else {
|
} else {
|
||||||
SetL1Dir(dir, grad, weight, num_dim);
|
SetL1Dir(dir, grad, weight, num_dim);
|
||||||
vdot = -Dot(dir, dir, num_dim);
|
vdot = -Dot(dir, dir, num_dim);
|
||||||
}
|
}
|
||||||
@ -483,6 +489,7 @@ class LBFGSSolver {
|
|||||||
num_iteration = 0;
|
num_iteration = 0;
|
||||||
num_dim = 0;
|
num_dim = 0;
|
||||||
old_objval = 0.0;
|
old_objval = 0.0;
|
||||||
|
offset_ = 0;
|
||||||
}
|
}
|
||||||
~GlobalState(void) {
|
~GlobalState(void) {
|
||||||
if (grad != NULL) {
|
if (grad != NULL) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user