refactor config
This commit is contained in:
parent
dafa44753a
commit
3589e8252f
@ -57,7 +57,6 @@ class IObjFunction{
|
|||||||
return base_score;
|
return base_score;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace learner
|
} // namespace learner
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
@ -68,10 +67,10 @@ namespace xgboost {
|
|||||||
namespace learner {
|
namespace learner {
|
||||||
/*! \brief factory function to create objective function by name */
|
/*! \brief factory function to create objective function by name */
|
||||||
inline IObjFunction* CreateObjFunction(const char *name) {
|
inline IObjFunction* CreateObjFunction(const char *name) {
|
||||||
if (!strcmp("reg:linear", name)) return new RegLossObj( LossType::kLinearSquare );
|
if (!strcmp("reg:linear", name)) return new RegLossObj(LossType::kLinearSquare);
|
||||||
if (!strcmp("reg:logistic", name)) return new RegLossObj( LossType::kLogisticNeglik );
|
if (!strcmp("reg:logistic", name)) return new RegLossObj(LossType::kLogisticNeglik);
|
||||||
if (!strcmp("binary:logistic", name)) return new RegLossObj( LossType::kLogisticClassify );
|
if (!strcmp("binary:logistic", name)) return new RegLossObj(LossType::kLogisticClassify);
|
||||||
if (!strcmp("binary:logitraw", name)) return new RegLossObj( LossType::kLogisticRaw );
|
if (!strcmp("binary:logitraw", name)) return new RegLossObj(LossType::kLogisticRaw);
|
||||||
utils::Error("unknown objective function type: %s", name);
|
utils::Error("unknown objective function type: %s", name);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
FMatrix &fmat,
|
FMatrix &fmat,
|
||||||
const std::vector<unsigned> &root_index,
|
const std::vector<unsigned> &root_index,
|
||||||
const std::vector<RegTree*> &trees) {
|
const std::vector<RegTree*> &trees) {
|
||||||
fmat.InitColAccess();
|
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
Builder builder(param);
|
Builder builder(param);
|
||||||
builder.Update(gpair, fmat, root_index, trees[i]);
|
builder.Update(gpair, fmat, root_index, trees[i]);
|
||||||
|
|||||||
391
utils/config.h
391
utils/config.h
@ -1,219 +1,196 @@
|
|||||||
#ifndef XGBOOST_CONFIG_H
|
#ifndef XGBOOST_UTILS_CONFIG_H_
|
||||||
#define XGBOOST_CONFIG_H
|
#define XGBOOST_UTILS_CONFIG_H_
|
||||||
/*!
|
/*!
|
||||||
* \file xgboost_config.h
|
* \file config.h
|
||||||
* \brief helper class to load configuration settings from a file
|
* \brief helper class to load configuration settings from a file
|
||||||
* \author Tianqi Chen: tianqi.tchen@gmail.com
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#define _CRT_SECURE_NO_WARNINGS
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "xgboost_utils.h"
|
#include <istream>
|
||||||
#include <vector>
|
#include <fstream>
|
||||||
|
#include "./utils.h"
|
||||||
|
|
||||||
namespace xgboost{
|
namespace xgboost {
|
||||||
namespace utils{
|
namespace utils {
|
||||||
/*!
|
/*!
|
||||||
* \brief an iterator that iterates over a configuration file and returns its settings
|
* \brief base implementation of config reader
|
||||||
*/
|
*/
|
||||||
class ConfigIterator{
|
class ConfigReaderBase {
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
* \brief constructor
|
* \brief get current name, called after Next returns true
|
||||||
* \param fname name of configure file
|
* \return current parameter name
|
||||||
*/
|
*/
|
||||||
ConfigIterator(const char *fname){
|
inline const char *name(void) const {
|
||||||
fi = FopenCheck(fname, "r");
|
return s_name;
|
||||||
ch_buf = fgetc(fi);
|
}
|
||||||
}
|
/*!
|
||||||
/*! \brief destructor */
|
* \brief get current value, called after Next returns true
|
||||||
~ConfigIterator(){
|
* \return current parameter value
|
||||||
fclose(fi);
|
*/
|
||||||
}
|
inline const char *val(void) const {
|
||||||
/*!
|
return s_val;
|
||||||
* \brief get current name, called after Next returns true
|
}
|
||||||
* \return current parameter name
|
/*!
|
||||||
*/
|
* \brief move iterator to next position
|
||||||
inline const char *name(void)const{
|
* \return true if there is value in next position
|
||||||
return s_name;
|
*/
|
||||||
}
|
inline bool Next(void) {
|
||||||
/*!
|
while (!this->IsEnd()) {
|
||||||
* \brief get current value, called after Next returns true
|
GetNextToken(s_name);
|
||||||
* \return current parameter value
|
if (s_name[0] == '=') return false;
|
||||||
*/
|
if (GetNextToken( s_buf ) || s_buf[0] != '=') return false;
|
||||||
inline const char *val(void) const{
|
if (GetNextToken( s_val ) || s_val[0] == '=') return false;
|
||||||
return s_val;
|
return true;
|
||||||
}
|
}
|
||||||
/*!
|
return false;
|
||||||
* \brief move iterator to next position
|
}
|
||||||
* \return true if there is value in next position
|
// called before usage
|
||||||
*/
|
inline void Init(void) {
|
||||||
inline bool Next(void){
|
ch_buf = this->GetChar();
|
||||||
while (!feof(fi)){
|
}
|
||||||
GetNextToken(s_name);
|
|
||||||
if (s_name[0] == '=') return false;
|
|
||||||
if (GetNextToken(s_buf) || s_buf[0] != '=') return false;
|
|
||||||
if (GetNextToken(s_val) || s_val[0] == '=') return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
private:
|
|
||||||
FILE *fi;
|
|
||||||
char ch_buf;
|
|
||||||
char s_name[256], s_val[256], s_buf[246];
|
|
||||||
|
|
||||||
inline void SkipLine(){
|
protected:
|
||||||
do{
|
/*!
|
||||||
ch_buf = fgetc(fi);
|
* \brief to be implemented by subclass,
|
||||||
} while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r');
|
* get next character, return EOF if end of file
|
||||||
}
|
*/
|
||||||
|
virtual char GetChar(void) = 0;
|
||||||
|
/*! \brief to be implemented by child, check if end of stream */
|
||||||
|
virtual bool IsEnd(void) = 0;
|
||||||
|
|
||||||
inline void ParseStr(char tok[]){
|
private:
|
||||||
int i = 0;
|
char ch_buf;
|
||||||
while ((ch_buf = fgetc(fi)) != EOF){
|
char s_name[100000], s_val[100000], s_buf[100000];
|
||||||
switch (ch_buf){
|
|
||||||
case '\\': tok[i++] = fgetc(fi); break;
|
|
||||||
case '\"': tok[i++] = '\0';
|
|
||||||
return;
|
|
||||||
case '\r':
|
|
||||||
case '\n': Error("unterminated string"); break;
|
|
||||||
default: tok[i++] = ch_buf;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Error("unterminated string");
|
|
||||||
}
|
|
||||||
// returns true if a line break or comment was passed before the token, or if the input ended
|
|
||||||
inline bool GetNextToken(char tok[]){
|
|
||||||
int i = 0;
|
|
||||||
bool new_line = false;
|
|
||||||
while (ch_buf != EOF){
|
|
||||||
switch (ch_buf){
|
|
||||||
case '#': SkipLine(); new_line = true; break;
|
|
||||||
case '\"':
|
|
||||||
if (i == 0){
|
|
||||||
ParseStr(tok); ch_buf = fgetc(fi); return new_line;
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
Error("token followed directly by string");
|
|
||||||
}
|
|
||||||
case '=':
|
|
||||||
if (i == 0) {
|
|
||||||
ch_buf = fgetc(fi);
|
|
||||||
tok[0] = '=';
|
|
||||||
tok[1] = '\0';
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
tok[i] = '\0';
|
|
||||||
}
|
|
||||||
return new_line;
|
|
||||||
case '\r':
|
|
||||||
case '\n':
|
|
||||||
if (i == 0) new_line = true;
|
|
||||||
case '\t':
|
|
||||||
case ' ':
|
|
||||||
ch_buf = fgetc(fi);
|
|
||||||
if (i > 0){
|
|
||||||
tok[i] = '\0';
|
|
||||||
return new_line;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
tok[i++] = ch_buf;
|
|
||||||
ch_buf = fgetc(fi);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace utils{
|
inline void SkipLine(void) {
|
||||||
/*!
|
do {
|
||||||
* \brief a class that saves parameter configurations
|
ch_buf = this->GetChar();
|
||||||
* temporarily and allows retrieving them later
|
} while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r');
|
||||||
* there are two kinds of priority in ConfigSaver
|
}
|
||||||
*/
|
|
||||||
class ConfigSaver{
|
inline void ParseStr(char tok[]) {
|
||||||
public:
|
int i = 0;
|
||||||
/*! \brief constructor */
|
while ((ch_buf = this->GetChar()) != EOF) {
|
||||||
ConfigSaver(void){ idx = 0; }
|
switch (ch_buf) {
|
||||||
/*! \brief clear all saves */
|
case '\\': tok[i++] = this->GetChar(); break;
|
||||||
inline void Clear(void){
|
case '\"': tok[i++] = '\0'; return;
|
||||||
idx = 0;
|
case '\r':
|
||||||
names.clear(); values.clear();
|
case '\n': Error("ConfigReader: unterminated string");
|
||||||
names_high.clear(); values_high.clear();
|
default: tok[i++] = ch_buf;
|
||||||
}
|
}
|
||||||
/*!
|
}
|
||||||
* \brief push back a parameter setting
|
Error("ConfigReader: unterminated string");
|
||||||
* \param name name of parameter
|
}
|
||||||
* \param val value of parameter
|
inline void ParseStrML(char tok[]) {
|
||||||
* \param priority whether the setting has higher priority: high priority occurs
|
int i = 0;
|
||||||
* later when read from ConfigSaver, and can overwrite existing settings
|
while ((ch_buf = this->GetChar()) != EOF) {
|
||||||
*/
|
switch (ch_buf) {
|
||||||
inline void PushBack(const char *name, const char *val, int priority = 0){
|
case '\\': tok[i++] = this->GetChar(); break;
|
||||||
if (priority == 0){
|
case '\'': tok[i++] = '\0'; return;
|
||||||
names.push_back(std::string(name));
|
default: tok[i++] = ch_buf;
|
||||||
values.push_back(std::string(val));
|
}
|
||||||
}
|
}
|
||||||
else{
|
Error("unterminated string");
|
||||||
names_high.push_back(std::string(name));
|
}
|
||||||
values_high.push_back(std::string(val));
|
// returns true if a line break or comment was passed before the token, or if the input ended
|
||||||
}
|
inline bool GetNextToken(char tok[]) {
|
||||||
}
|
int i = 0;
|
||||||
/*! \brief set pointer to beginning of the ConfigSaver */
|
bool new_line = false;
|
||||||
inline void BeforeFirst(void){
|
while (ch_buf != EOF) {
|
||||||
idx = 0;
|
switch (ch_buf) {
|
||||||
}
|
case '#' : SkipLine(); new_line = true; break;
|
||||||
/*!
|
case '\"':
|
||||||
* \brief move iterator to next position
|
if (i == 0) {
|
||||||
* \return true if there is value in next position
|
ParseStr(tok); ch_buf = this->GetChar(); return new_line;
|
||||||
*/
|
} else {
|
||||||
inline bool Next(void){
|
Error("ConfigReader: token followed directly by string");
|
||||||
if (idx >= names.size() + names_high.size()){
|
}
|
||||||
return false;
|
case '\'':
|
||||||
}
|
if (i == 0) {
|
||||||
idx++;
|
ParseStrML( tok ); ch_buf = this->GetChar(); return new_line;
|
||||||
return true;
|
} else {
|
||||||
}
|
Error("ConfigReader: token followed directly by string");
|
||||||
/*!
|
}
|
||||||
* \brief get current name, called after Next returns true
|
case '=':
|
||||||
* \return current parameter name
|
if (i == 0) {
|
||||||
*/
|
ch_buf = this->GetChar();
|
||||||
inline const char *name(void) const{
|
tok[0] = '=';
|
||||||
Assert(idx > 0, "can't call name before first");
|
tok[1] = '\0';
|
||||||
size_t i = idx - 1;
|
} else {
|
||||||
if (i >= names.size()){
|
tok[i] = '\0';
|
||||||
return names_high[i - names.size()].c_str();
|
}
|
||||||
}
|
return new_line;
|
||||||
else{
|
case '\r':
|
||||||
return names[i].c_str();
|
case '\n':
|
||||||
}
|
if (i == 0) new_line = true;
|
||||||
}
|
case '\t':
|
||||||
/*!
|
case ' ' :
|
||||||
* \brief get current value, called after Next returns true
|
ch_buf = this->GetChar();
|
||||||
* \return current parameter value
|
if (i > 0) {
|
||||||
*/
|
tok[i] = '\0';
|
||||||
inline const char *val(void) const{
|
return new_line;
|
||||||
Assert(idx > 0, "can't call name before first");
|
}
|
||||||
size_t i = idx - 1;
|
break;
|
||||||
if (i >= values.size()){
|
default:
|
||||||
return values_high[i - values.size()].c_str();
|
tok[i++] = ch_buf;
|
||||||
}
|
ch_buf = this->GetChar();
|
||||||
else{
|
break;
|
||||||
return values[i].c_str();
|
}
|
||||||
}
|
}
|
||||||
}
|
return true;
|
||||||
private:
|
}
|
||||||
std::vector<std::string> names;
|
|
||||||
std::vector<std::string> values;
|
|
||||||
std::vector<std::string> names_high;
|
|
||||||
std::vector<std::string> values_high;
|
|
||||||
size_t idx;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
#endif
|
/*!
|
||||||
|
* \brief a stream-based config iterator; works with any std::istream
|
||||||
|
*/
|
||||||
|
class ConfigStreamReader: public ConfigReaderBase {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param fin input stream
|
||||||
|
*/
|
||||||
|
explicit ConfigStreamReader(std::istream &fin) : fin(fin) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual char GetChar(void) {
|
||||||
|
return fin.get();
|
||||||
|
}
|
||||||
|
/*! \brief check if end of stream (implements ConfigReaderBase::IsEnd) */
|
||||||
|
virtual bool IsEnd(void) {
|
||||||
|
return fin.eof();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::istream &fin;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief an iterator that iterates over a configuration file and returns its settings
|
||||||
|
*/
|
||||||
|
class ConfigIterator: public ConfigStreamReader {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param fname name of configure file
|
||||||
|
*/
|
||||||
|
explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) {
|
||||||
|
fi.open(fname);
|
||||||
|
if (fi.fail()) {
|
||||||
|
utils::Error("cannot open file %s", fname);
|
||||||
|
}
|
||||||
|
ConfigReaderBase::Init();
|
||||||
|
}
|
||||||
|
/*! \brief destructor */
|
||||||
|
~ConfigIterator(void) {
|
||||||
|
fi.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::ifstream fi;
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_UTILS_CONFIG_H_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user