some fix
This commit is contained in:
parent
bbe4957cd2
commit
50af92e29e
@ -75,7 +75,7 @@ class DMatrix:
|
||||
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group) )
|
||||
# set weight of each instances
|
||||
def set_weight(self, weight):
|
||||
xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_uint*len(weight))(*weight), len(weight) )
|
||||
xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_float*len(weight))(*weight), len(weight) )
|
||||
# get label from dmatrix
|
||||
def get_label(self):
|
||||
length = ctypes.c_ulong()
|
||||
|
||||
@ -223,7 +223,7 @@ extern "C"{
|
||||
mats.push_back( static_cast<DMatrix*>(dmats[i]) );
|
||||
names.push_back( std::string( evnames[i]) );
|
||||
}
|
||||
bst->EvalOneIter( iter, mats, names, stdout );
|
||||
bst->EvalOneIter( iter, mats, names, stderr );
|
||||
}
|
||||
const float *XGBoosterPredict( void *handle, void *dmat, size_t *len, int bst_group ){
|
||||
return static_cast<Booster*>(handle)->Pred( *static_cast<DMatrix*>(dmat), len, bst_group );
|
||||
|
||||
@ -328,7 +328,7 @@ namespace xgboost{
|
||||
* \brief adjust base_score
|
||||
*/
|
||||
inline void AdjustBase(void){
|
||||
if (loss_type == 1 || loss_type == 2){
|
||||
if (loss_type == 1 || loss_type == 2|| loss_type == 3){
|
||||
utils::Assert(base_score > 0.0f && base_score < 1.0f, "sigmoid range constrain");
|
||||
base_score = -logf(1.0f / base_score - 1.0f);
|
||||
}
|
||||
|
||||
@ -99,26 +99,44 @@ namespace xgboost{
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief AMS */
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public IEvaluator{
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const DMatrix::Info &info) const {
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
double s_tp = 0.0, b_fp = 0.0;
|
||||
#pragma omp parallel for reduction(+:s_tp,b_fp) schedule( static )
|
||||
std::vector< std::pair<float, unsigned> > rec(ndata);
|
||||
|
||||
#pragma omp parallel for schedule( static )
|
||||
for (unsigned i = 0; i < ndata; ++i){
|
||||
const float wt = info.GetWeight(i);
|
||||
if (preds[i] > 0.5f){
|
||||
if( info.labels[i] > 0.5f ) s_tp += wt;
|
||||
else b_fp += wt;
|
||||
}
|
||||
}
|
||||
rec[i] = std::make_pair( preds[i], i );
|
||||
}
|
||||
std::sort( rec.begin(), rec.end(), CmpFirst );
|
||||
|
||||
const double br = 10.0;
|
||||
return sqrtf( 2*((s_tp+b_fp+br) * log( 1.0 + s_tp/(b_fp+br) ) - s_tp) );
|
||||
double s_tp = 0.0, b_fp = 0.0, tams = 0.0, threshold = 0.0;
|
||||
for (unsigned i = 0; i < ndata-1; ++i){
|
||||
const unsigned ridx = rec[i].second;
|
||||
const float wt = info.weights[ridx];
|
||||
if( info.labels[ridx] > 0.5f ){
|
||||
s_tp += wt;
|
||||
}else{
|
||||
b_fp += wt;
|
||||
}
|
||||
if( rec[i].first != rec[i+1].first ){
|
||||
double ams = sqrtf( 2*((s_tp+b_fp+br) * log( 1.0 + s_tp/(b_fp+br) ) - s_tp) );
|
||||
if( tams < ams ){
|
||||
threshold = (rec[i].first + rec[i+1].first) / 2.0;
|
||||
tams = ams;
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf( stderr, "\tams-thres=%g", threshold );
|
||||
return tams;
|
||||
}
|
||||
virtual const char *Name(void) const{
|
||||
return "ams";
|
||||
}
|
||||
double wtarget;
|
||||
};
|
||||
|
||||
/*! \brief Error */
|
||||
|
||||
@ -50,6 +50,7 @@ namespace xgboost{
|
||||
const static int kLinearSquare = 0;
|
||||
const static int kLogisticNeglik = 1;
|
||||
const static int kLogisticClassify = 2;
|
||||
const static int kLogisticRaw = 3;
|
||||
public:
|
||||
/*! \brief indicate which type we are using */
|
||||
int loss_type;
|
||||
@ -61,6 +62,7 @@ namespace xgboost{
|
||||
*/
|
||||
inline float PredTransform(float x){
|
||||
switch (loss_type){
|
||||
case kLogisticRaw:
|
||||
case kLinearSquare: return x;
|
||||
case kLogisticClassify:
|
||||
case kLogisticNeglik: return 1.0f / (1.0f + expf(-x));
|
||||
@ -77,6 +79,7 @@ namespace xgboost{
|
||||
inline float FirstOrderGradient(float predt, float label) const{
|
||||
switch (loss_type){
|
||||
case kLinearSquare: return predt - label;
|
||||
case kLogisticRaw: predt = 1.0f / (1.0f + expf(-predt));
|
||||
case kLogisticClassify:
|
||||
case kLogisticNeglik: return predt - label;
|
||||
default: utils::Error("unknown loss_type"); return 0.0f;
|
||||
@ -91,6 +94,7 @@ namespace xgboost{
|
||||
inline float SecondOrderGradient(float predt, float label) const{
|
||||
switch (loss_type){
|
||||
case kLinearSquare: return 1.0f;
|
||||
case kLogisticRaw: predt = 1.0f / (1.0f + expf(-predt));
|
||||
case kLogisticClassify:
|
||||
case kLogisticNeglik: return predt * (1 - predt);
|
||||
default: utils::Error("unknown loss_type"); return 0.0f;
|
||||
|
||||
@ -20,6 +20,7 @@ namespace xgboost{
|
||||
virtual ~RegressionObj(){}
|
||||
virtual void SetParam(const char *name, const char *val){
|
||||
if( !strcmp( "loss_type", name ) ) loss.loss_type = atoi( val );
|
||||
if( !strcmp( "scale_pos_weight", name ) ) scale_pos_weight = (float)atof( val );
|
||||
}
|
||||
virtual void GetGradient(const std::vector<float>& preds,
|
||||
const DMatrix::Info &info,
|
||||
@ -33,13 +34,16 @@ namespace xgboost{
|
||||
#pragma omp parallel for schedule( static )
|
||||
for (unsigned j = 0; j < ndata; ++j){
|
||||
float p = loss.PredTransform(preds[j]);
|
||||
grad[j] = loss.FirstOrderGradient(p, info.labels[j]) * info.GetWeight(j);
|
||||
hess[j] = loss.SecondOrderGradient(p, info.labels[j]) * info.GetWeight(j);
|
||||
float w = info.GetWeight(j);
|
||||
if( info.labels[j] == 1.0f ) w *= scale_pos_weight;
|
||||
grad[j] = loss.FirstOrderGradient(p, info.labels[j]) * w;
|
||||
hess[j] = loss.SecondOrderGradient(p, info.labels[j]) * w;
|
||||
}
|
||||
}
|
||||
virtual const char* DefaultEvalMetric(void) {
|
||||
if( loss.loss_type == LossType::kLogisticClassify ) return "error";
|
||||
else return "rmse";
|
||||
if( loss.loss_type == LossType::kLogisticRaw ) return "auc";
|
||||
return "rmse";
|
||||
}
|
||||
virtual void PredTransform(std::vector<float> &preds){
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
@ -49,6 +53,7 @@ namespace xgboost{
|
||||
}
|
||||
}
|
||||
private:
|
||||
float scale_pos_weight;
|
||||
LossType loss;
|
||||
};
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user