add changes

This commit is contained in:
tqchen
2014-08-24 17:25:17 -07:00
parent da75f8f1a4
commit 7874c2559b
5 changed files with 52 additions and 28 deletions

View File

@@ -62,10 +62,10 @@ class DMatrixSimple : public DataMatrix {
inline size_t AddRow(const std::vector<SparseBatch::Entry> &feats) {
for (size_t i = 0; i < feats.size(); ++i) {
row_data_.push_back(feats[i]);
info.num_col = std::max(info.num_col, static_cast<size_t>(feats[i].findex+1));
info.info.num_col = std::max(info.info.num_col, static_cast<size_t>(feats[i].findex+1));
}
row_ptr_.push_back(row_ptr_.back() + feats.size());
info.num_row += 1;
info.info.num_row += 1;
return row_ptr_.size() - 2;
}
/*!
@@ -99,19 +99,19 @@ class DMatrixSimple : public DataMatrix {
if (!silent) {
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
info.num_row, info.num_col, row_data_.size(), fname);
info.num_row(), info.num_col(), row_data_.size(), fname);
}
fclose(file);
// try to load in additional file
std::string name = fname;
std::string gname = name + ".group";
if (info.TryLoadGroup(gname.c_str(), silent)) {
utils::Check(info.group_ptr.back() == info.num_row,
utils::Check(info.group_ptr.back() == info.num_row(),
"DMatrix: group data does not match the number of rows in features");
}
std::string wname = name + ".weight";
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
utils::Check(info.weights.size() == info.num_row,
utils::Check(info.weights.size() == info.num_row(),
"DMatrix: weight data does not match the number of rows in features");
}
std::string mname = name + ".base_margin";
@@ -139,7 +139,7 @@ class DMatrixSimple : public DataMatrix {
if (!silent) {
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
info.num_row, info.num_col, row_data_.size(), fname);
info.num_row(), info.num_col(), row_data_.size(), fname);
if (info.group_ptr.size() != 0) {
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
}
@@ -163,7 +163,7 @@ class DMatrixSimple : public DataMatrix {
if (!silent) {
printf("%lux%lu matrix with %lu entries is saved to %s\n",
info.num_row, info.num_col, row_data_.size(), fname);
info.num_row(), info.num_col(), row_data_.size(), fname);
if (info.group_ptr.size() != 0) {
printf("data contains %lu groups\n", info.group_ptr.size()-1);
}