[FIX] change evaluation to more precision
This commit is contained in:
parent
67fbf8d264
commit
fd173e260f
@ -158,23 +158,24 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
LOG(CONSOLE) << "Partial load option on npart=" << npart;
|
LOG(CONSOLE) << "Partial load option on npart=" << npart;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy handling of binary data loading
|
// legacy handling of binary data loading
|
||||||
if (file_format == "auto" && !load_row_split) {
|
if (file_format == "auto" && !load_row_split) {
|
||||||
int magic;
|
int magic;
|
||||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
|
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||||
common::PeekableInStream is(fi.get());
|
if (fi.get() != nullptr) {
|
||||||
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) &&
|
common::PeekableInStream is(fi.get());
|
||||||
magic == data::SimpleCSRSource::kMagic) {
|
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) &&
|
||||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
magic == data::SimpleCSRSource::kMagic) {
|
||||||
source->LoadBinary(&is);
|
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||||
DMatrix* dmat = DMatrix::Create(std::move(source), cache_file);
|
source->LoadBinary(&is);
|
||||||
if (!silent) {
|
DMatrix* dmat = DMatrix::Create(std::move(source), cache_file);
|
||||||
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
|
if (!silent) {
|
||||||
<< dmat->info().num_nonzero << " entries loaded from " << uri;
|
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
|
||||||
}
|
<< dmat->info().num_nonzero << " entries loaded from " << uri;
|
||||||
return dmat;
|
}
|
||||||
}
|
return dmat;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ftype = file_format;
|
std::string ftype = file_format;
|
||||||
|
|||||||
@ -28,15 +28,15 @@ struct EvalEWiseBase : public Metric {
|
|||||||
CHECK_EQ(preds.size(), info.labels.size())
|
CHECK_EQ(preds.size(), info.labels.size())
|
||||||
<< "label and prediction size not match, "
|
<< "label and prediction size not match, "
|
||||||
<< "hint: use merror or mlogloss for multi-class classification";
|
<< "hint: use merror or mlogloss for multi-class classification";
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
|
const omp_ulong ndata = static_cast<omp_ulong>(info.labels.size());
|
||||||
float sum = 0.0, wsum = 0.0;
|
double sum = 0.0, wsum = 0.0;
|
||||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||||
const float wt = info.GetWeight(i);
|
const float wt = info.GetWeight(i);
|
||||||
sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
|
sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
|
||||||
wsum += wt;
|
wsum += wt;
|
||||||
}
|
}
|
||||||
float dat[2]; dat[0] = sum, dat[1] = wsum;
|
double dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||||
if (distributed) {
|
if (distributed) {
|
||||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user