refactor config

This commit is contained in:
tqchen 2014-08-15 21:02:33 -07:00
parent dafa44753a
commit 3589e8252f
3 changed files with 189 additions and 213 deletions

View File

@ -57,7 +57,6 @@ class IObjFunction{
return base_score; return base_score;
} }
}; };
} // namespace learner } // namespace learner
} // namespace xgboost } // namespace xgboost
@ -68,10 +67,10 @@ namespace xgboost {
namespace learner { namespace learner {
/*! \brief factory funciton to create objective function by name */ /*! \brief factory funciton to create objective function by name */
inline IObjFunction* CreateObjFunction(const char *name) { inline IObjFunction* CreateObjFunction(const char *name) {
if (!strcmp("reg:linear", name)) return new RegLossObj( LossType::kLinearSquare ); if (!strcmp("reg:linear", name)) return new RegLossObj(LossType::kLinearSquare);
if (!strcmp("reg:logistic", name)) return new RegLossObj( LossType::kLogisticNeglik ); if (!strcmp("reg:logistic", name)) return new RegLossObj(LossType::kLogisticNeglik);
if (!strcmp("binary:logistic", name)) return new RegLossObj( LossType::kLogisticClassify ); if (!strcmp("binary:logistic", name)) return new RegLossObj(LossType::kLogisticClassify);
if (!strcmp("binary:logitraw", name)) return new RegLossObj( LossType::kLogisticRaw ); if (!strcmp("binary:logitraw", name)) return new RegLossObj(LossType::kLogisticRaw);
utils::Error("unknown objective function type: %s", name); utils::Error("unknown objective function type: %s", name);
return NULL; return NULL;
} }

View File

@ -27,7 +27,7 @@ class ColMaker: public IUpdater<FMatrix> {
FMatrix &fmat, FMatrix &fmat,
const std::vector<unsigned> &root_index, const std::vector<unsigned> &root_index,
const std::vector<RegTree*> &trees) { const std::vector<RegTree*> &trees) {
fmat.InitColAccess();
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param); Builder builder(param);
builder.Update(gpair, fmat, root_index, trees[i]); builder.Update(gpair, fmat, root_index, trees[i]);

View File

@ -1,219 +1,196 @@
#ifndef XGBOOST_CONFIG_H #ifndef XGBOOST_UTILS_CONFIG_H_
#define XGBOOST_CONFIG_H #define XGBOOST_UTILS_CONFIG_H_
/*! /*!
* \file xgboost_config.h * \file config.h
* \brief helper class to load in configures from file * \brief helper class to load in configures from file
* \author Tianqi Chen: tianqi.tchen@gmail.com * \author Tianqi Chen
*/ */
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <string> #include <string>
#include "xgboost_utils.h" #include <istream>
#include <vector> #include <fstream>
#include "./utils.h"
namespace xgboost{ namespace xgboost {
namespace utils{ namespace utils {
/*! /*!
* \brief an iterator that iterates over a configure file and gets the configures * \brief base implementation of config reader
*/ */
class ConfigIterator{ class ConfigReaderBase {
public: public:
/*! /*!
* \brief constructor * \brief get current name, called after Next returns true
* \param fname name of configure file * \return current parameter name
*/ */
ConfigIterator(const char *fname){ inline const char *name(void) const {
fi = FopenCheck(fname, "r"); return s_name;
ch_buf = fgetc(fi); }
} /*!
/*! \brief destructor */ * \brief get current value, called after Next returns true
~ConfigIterator(){ * \return current parameter value
fclose(fi); */
} inline const char *val(void) const {
/*! return s_val;
* \brief get current name, called after Next returns true }
* \return current parameter name /*!
*/ * \brief move iterator to next position
inline const char *name(void)const{ * \return true if there is value in next position
return s_name; */
} inline bool Next(void) {
/*! while (!this->IsEnd()) {
* \brief get current value, called after Next returns true GetNextToken(s_name);
* \return current parameter value if (s_name[0] == '=') return false;
*/ if (GetNextToken( s_buf ) || s_buf[0] != '=') return false;
inline const char *val(void) const{ if (GetNextToken( s_val ) || s_val[0] == '=') return false;
return s_val; return true;
} }
/*! return false;
* \brief move iterator to next position }
* \return true if there is value in next position // called before usage
*/ inline void Init(void) {
inline bool Next(void){ ch_buf = this->GetChar();
while (!feof(fi)){ }
GetNextToken(s_name);
if (s_name[0] == '=') return false;
if (GetNextToken(s_buf) || s_buf[0] != '=') return false;
if (GetNextToken(s_val) || s_val[0] == '=') return false;
return true;
}
return false;
}
private:
FILE *fi;
char ch_buf;
char s_name[256], s_val[256], s_buf[246];
inline void SkipLine(){ protected:
do{ /*!
ch_buf = fgetc(fi); * \brief to be implemented by subclass,
} while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r'); * get next token, return EOF if end of file
} */
virtual char GetChar(void) = 0;
/*! \brief to be implemented by child, check if end of stream */
virtual bool IsEnd(void) = 0;
inline void ParseStr(char tok[]){ private:
int i = 0; char ch_buf;
while ((ch_buf = fgetc(fi)) != EOF){ char s_name[100000], s_val[100000], s_buf[100000];
switch (ch_buf){
case '\\': tok[i++] = fgetc(fi); break;
case '\"': tok[i++] = '\0';
return;
case '\r':
case '\n': Error("unterminated string"); break;
default: tok[i++] = ch_buf;
}
}
Error("unterminated string");
}
// return newline
inline bool GetNextToken(char tok[]){
int i = 0;
bool new_line = false;
while (ch_buf != EOF){
switch (ch_buf){
case '#': SkipLine(); new_line = true; break;
case '\"':
if (i == 0){
ParseStr(tok); ch_buf = fgetc(fi); return new_line;
}
else{
Error("token followed directly by string");
}
case '=':
if (i == 0) {
ch_buf = fgetc(fi);
tok[0] = '=';
tok[1] = '\0';
}
else{
tok[i] = '\0';
}
return new_line;
case '\r':
case '\n':
if (i == 0) new_line = true;
case '\t':
case ' ':
ch_buf = fgetc(fi);
if (i > 0){
tok[i] = '\0';
return new_line;
}
break;
default:
tok[i++] = ch_buf;
ch_buf = fgetc(fi);
break;
}
}
return true;
}
};
};
namespace utils{ inline void SkipLine(void) {
/*! do {
* \brief a class that save parameter configurations ch_buf = this->GetChar();
* temporally and allows to get them out later } while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r');
* there are two kinds of priority in ConfigSaver }
*/
class ConfigSaver{ inline void ParseStr(char tok[]) {
public: int i = 0;
/*! \brief constructor */ while ((ch_buf = this->GetChar()) != EOF) {
ConfigSaver(void){ idx = 0; } switch (ch_buf) {
/*! \brief clear all saves */ case '\\': tok[i++] = this->GetChar(); break;
inline void Clear(void){ case '\"': tok[i++] = '\0'; return;
idx = 0; case '\r':
names.clear(); values.clear(); case '\n': Error("ConfigReader: unterminated string");
names_high.clear(); values_high.clear(); default: tok[i++] = ch_buf;
} }
/*! }
* \brief push back a parameter setting Error("ConfigReader: unterminated string");
* \param name name of parameter }
* \param val value of parameter inline void ParseStrML(char tok[]) {
* \param priority whether the setting has higher priority: high priority occurs int i = 0;
* latter when read from ConfigSaver, and can overwrite existing settings while ((ch_buf = this->GetChar()) != EOF) {
*/ switch (ch_buf) {
inline void PushBack(const char *name, const char *val, int priority = 0){ case '\\': tok[i++] = this->GetChar(); break;
if (priority == 0){ case '\'': tok[i++] = '\0'; return;
names.push_back(std::string(name)); default: tok[i++] = ch_buf;
values.push_back(std::string(val)); }
} }
else{ Error("unterminated string");
names_high.push_back(std::string(name)); }
values_high.push_back(std::string(val)); // return newline
} inline bool GetNextToken(char tok[]) {
} int i = 0;
/*! \brief set pointer to beginning of the ConfigSaver */ bool new_line = false;
inline void BeforeFirst(void){ while (ch_buf != EOF) {
idx = 0; switch (ch_buf) {
} case '#' : SkipLine(); new_line = true; break;
/*! case '\"':
* \brief move iterator to next position if (i == 0) {
* \return true if there is value in next position ParseStr(tok); ch_buf = this->GetChar(); return new_line;
*/ } else {
inline bool Next(void){ Error("ConfigReader: token followed directly by string");
if (idx >= names.size() + names_high.size()){ }
return false; case '\'':
} if (i == 0) {
idx++; ParseStrML( tok ); ch_buf = this->GetChar(); return new_line;
return true; } else {
} Error("ConfigReader: token followed directly by string");
/*! }
* \brief get current name, called after Next returns true case '=':
* \return current parameter name if (i == 0) {
*/ ch_buf = this->GetChar();
inline const char *name(void) const{ tok[0] = '=';
Assert(idx > 0, "can't call name before first"); tok[1] = '\0';
size_t i = idx - 1; } else {
if (i >= names.size()){ tok[i] = '\0';
return names_high[i - names.size()].c_str(); }
} return new_line;
else{ case '\r':
return names[i].c_str(); case '\n':
} if (i == 0) new_line = true;
} case '\t':
/*! case ' ' :
* \brief get current value, called after Next returns true ch_buf = this->GetChar();
* \return current parameter value if (i > 0) {
*/ tok[i] = '\0';
inline const char *val(void) const{ return new_line;
Assert(idx > 0, "can't call name before first"); }
size_t i = idx - 1; break;
if (i >= values.size()){ default:
return values_high[i - values.size()].c_str(); tok[i++] = ch_buf;
} ch_buf = this->GetChar();
else{ break;
return values[i].c_str(); }
} }
} return true;
private: }
std::vector<std::string> names;
std::vector<std::string> values;
std::vector<std::string> names_high;
std::vector<std::string> values_high;
size_t idx;
};
};
}; };
#endif /*!
* \brief an iterator use stream base, allows use all types of istream
*/
class ConfigStreamReader: public ConfigReaderBase {
public:
/*!
* \brief constructor
* \param istream input stream
*/
explicit ConfigStreamReader(std::istream &fin) : fin(fin) {}
protected:
virtual char GetChar(void) {
return fin.get();
}
/*! \brief to be implemented by child, check if end of stream */
virtual bool IsEnd(void) {
return fin.eof();
}
private:
std::istream &fin;
};
/*!
* \brief an iterator that iterates over a configure file and gets the configures
*/
class ConfigIterator: public ConfigStreamReader {
public:
/*!
* \brief constructor
* \param fname name of configure file
*/
explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) {
fi.open(fname);
if (fi.fail()) {
utils::Error("cannot open file %s", fname);
}
ConfigReaderBase::Init();
}
/*! \brief destructor */
~ConfigIterator(void) {
fi.close();
}
private:
std::ifstream fi;
};
} // namespace utils
} // namespace xgboost
#endif // XGBOOST_UTILS_CONFIG_H_