ok
This commit is contained in:
parent
5a472145de
commit
4ed4b08146
@ -15,7 +15,7 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) {
|
void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) {
|
||||||
if (dmat.magic == DMatrixSimple::kMagic){
|
if (dmat.magic == DMatrixSimple::kMagic) {
|
||||||
const DMatrixSimple *p_dmat = static_cast<const DMatrixSimple*>(&dmat);
|
const DMatrixSimple *p_dmat = static_cast<const DMatrixSimple*>(&dmat);
|
||||||
p_dmat->SaveBinary(fname, silent);
|
p_dmat->SaveBinary(fname, silent);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -30,7 +30,7 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent = false, bool savebuff
|
|||||||
* \param fname file name to be saved
|
* \param fname file name to be saved
|
||||||
* \param silent whether to print messages during saving
|
* \param silent whether to print messages during saving
|
||||||
*/
|
*/
|
||||||
void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent = false);
|
void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent = false);
|
||||||
|
|
||||||
} // namespace io
|
} // namespace io
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -203,7 +203,7 @@ class BoostLearner {
|
|||||||
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
inline std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||||
return gbm_->DumpModel(fmap, option);
|
return gbm_->DumpModel(fmap, option);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize the objective function and GBM,
|
* \brief initialize the objective function and GBM,
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "../data.h"
|
#include "../data.h"
|
||||||
#include "./objective.h"
|
#include "./objective.h"
|
||||||
@ -254,19 +255,17 @@ class LambdaRankObj : public IObjFunction {
|
|||||||
utils::Check(gptr.size() != 0 && gptr.back() == preds.size(),
|
utils::Check(gptr.size() != 0 && gptr.back() == preds.size(),
|
||||||
"group structure not consistent with #rows");
|
"group structure not consistent with #rows");
|
||||||
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
||||||
|
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
// parallel construct: declare the random number generator here, so that each
|
// parallel construct: declare the random number generator here, so that each
|
||||||
// thread uses its own random number generator, seeded by thread id and current iteration
|
// thread uses its own random number generator, seeded by thread id and current iteration
|
||||||
random::Random rnd; rnd.Seed(iter* 1111 + omp_get_thread_num());
|
random::Random rnd; rnd.Seed(iter* 1111 + omp_get_thread_num());
|
||||||
std::vector<LambdaPair> pairs;
|
std::vector<LambdaPair> pairs;
|
||||||
std::vector<ListEntry> lst;
|
std::vector<ListEntry> lst;
|
||||||
std::vector< std::pair<float,unsigned> > rec;
|
std::vector< std::pair<float, unsigned> > rec;
|
||||||
|
|
||||||
#pragma omp for schedule(static)
|
#pragma omp for schedule(static)
|
||||||
for (unsigned k = 0; k < ngroup; ++k) {
|
for (unsigned k = 0; k < ngroup; ++k) {
|
||||||
lst.clear(); pairs.clear();
|
lst.clear(); pairs.clear();
|
||||||
for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
|
for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
|
||||||
lst.push_back(ListEntry(preds[j], info.labels[j], j));
|
lst.push_back(ListEntry(preds[j], info.labels[j], j));
|
||||||
gpair[j] = bst_gpair(0.0f, 0.0f);
|
gpair[j] = bst_gpair(0.0f, 0.0f);
|
||||||
@ -313,8 +312,8 @@ class LambdaRankObj : public IObjFunction {
|
|||||||
float g = loss.FirstOrderGradient(p, 1.0f);
|
float g = loss.FirstOrderGradient(p, 1.0f);
|
||||||
float h = loss.SecondOrderGradient(p, 1.0f);
|
float h = loss.SecondOrderGradient(p, 1.0f);
|
||||||
// accumulate gradient and hessian in both pid, and nid
|
// accumulate gradient and hessian in both pid, and nid
|
||||||
gpair[pos.rindex].grad += g * w;
|
gpair[pos.rindex].grad += g * w;
|
||||||
gpair[pos.rindex].hess += 2.0f * h;
|
gpair[pos.rindex].hess += 2.0f * h;
|
||||||
gpair[neg.rindex].grad -= g * w;
|
gpair[neg.rindex].grad -= g * w;
|
||||||
gpair[neg.rindex].hess += 2.0f * h;
|
gpair[neg.rindex].hess += 2.0f * h;
|
||||||
}
|
}
|
||||||
@ -332,7 +331,7 @@ class LambdaRankObj : public IObjFunction {
|
|||||||
float pred;
|
float pred;
|
||||||
/*! \brief the actual label of the entry */
|
/*! \brief the actual label of the entry */
|
||||||
float label;
|
float label;
|
||||||
/*! \brief row index in the data matrix */
|
/*! \brief row index in the data matrix */
|
||||||
unsigned rindex;
|
unsigned rindex;
|
||||||
// constructor
|
// constructor
|
||||||
ListEntry(float pred, float label, unsigned rindex)
|
ListEntry(float pred, float label, unsigned rindex)
|
||||||
@ -370,14 +369,14 @@ class LambdaRankObj : public IObjFunction {
|
|||||||
// loss function
|
// loss function
|
||||||
LossType loss;
|
LossType loss;
|
||||||
// number of samples performed for each instance
|
// number of samples performed for each instance
|
||||||
int num_pairsample;
|
int num_pairsample;
|
||||||
// fix weight of each elements in list
|
// fix weight of each elements in list
|
||||||
float fix_list_weight;
|
float fix_list_weight;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PairwiseRankObj: public LambdaRankObj{
|
class PairwiseRankObj: public LambdaRankObj{
|
||||||
public:
|
public:
|
||||||
virtual ~PairwiseRankObj(void){}
|
virtual ~PairwiseRankObj(void) {}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
virtual void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
||||||
@ -402,7 +401,6 @@ class LambdaRankObjNDCG : public LambdaRankObj {
|
|||||||
std::sort(labels.begin(), labels.end(), std::greater<float>());
|
std::sort(labels.begin(), labels.end(), std::greater<float>());
|
||||||
IDCG = CalcDCG(labels);
|
IDCG = CalcDCG(labels);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IDCG == 0.0) {
|
if (IDCG == 0.0) {
|
||||||
for (size_t i = 0; i < pairs.size(); ++i) {
|
for (size_t i = 0; i < pairs.size(); ++i) {
|
||||||
pairs[i].weight = 0.0f;
|
pairs[i].weight = 0.0f;
|
||||||
@ -412,13 +410,15 @@ class LambdaRankObjNDCG : public LambdaRankObj {
|
|||||||
for (size_t i = 0; i < pairs.size(); ++i) {
|
for (size_t i = 0; i < pairs.size(); ++i) {
|
||||||
unsigned pos_idx = pairs[i].pos_index;
|
unsigned pos_idx = pairs[i].pos_index;
|
||||||
unsigned neg_idx = pairs[i].neg_index;
|
unsigned neg_idx = pairs[i].neg_index;
|
||||||
float pos_loginv = 1.0f / logf(pos_idx+2.0f);
|
float pos_loginv = 1.0f / logf(pos_idx + 2.0f);
|
||||||
float neg_loginv = 1.0f / logf(neg_idx+2.0f);
|
float neg_loginv = 1.0f / logf(neg_idx + 2.0f);
|
||||||
int pos_label = static_cast<int>(sorted_list[pos_idx].label);
|
int pos_label = static_cast<int>(sorted_list[pos_idx].label);
|
||||||
int neg_label = static_cast<int>(sorted_list[neg_idx].label);
|
int neg_label = static_cast<int>(sorted_list[neg_idx].label);
|
||||||
float original = ((1<<pos_label)-1) * pos_loginv + ((1<<neg_label)-1) * neg_loginv;
|
float original =
|
||||||
float changed = ((1<<neg_label)-1) * pos_loginv + ((1<<pos_label)-1) * neg_loginv;
|
((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
|
||||||
float delta = (original-changed) * IDCG;
|
float changed =
|
||||||
|
((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
|
||||||
|
float delta = (original - changed) * IDCG;
|
||||||
if (delta < 0.0f) delta = - delta;
|
if (delta < 0.0f) delta = - delta;
|
||||||
pairs[i].weight = delta;
|
pairs[i].weight = delta;
|
||||||
}
|
}
|
||||||
@ -428,25 +428,31 @@ class LambdaRankObjNDCG : public LambdaRankObj {
|
|||||||
double sumdcg = 0.0;
|
double sumdcg = 0.0;
|
||||||
for (size_t i = 0; i < labels.size(); ++i) {
|
for (size_t i = 0; i < labels.size(); ++i) {
|
||||||
const unsigned rel = labels[i];
|
const unsigned rel = labels[i];
|
||||||
if (rel != 0) {
|
if (rel != 0) {
|
||||||
sumdcg += ((1<<rel)-1) / logf(i + 2);
|
sumdcg += ((1 << rel) - 1) / logf(i + 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return static_cast<float>(sumdcg);
|
return static_cast<float>(sumdcg);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class LambdaRankObjMAP : public LambdaRankObj {
|
class LambdaRankObjMAP : public LambdaRankObj {
|
||||||
public:
|
public:
|
||||||
virtual ~LambdaRankObjMAP(void) {}
|
virtual ~LambdaRankObjMAP(void) {}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
struct MAPStats {
|
struct MAPStats {
|
||||||
/* \brief the accumulated precision */
|
/*! \brief the accumulated precision */
|
||||||
float ap_acc;
|
float ap_acc;
|
||||||
/* \brief the accumulated precision assuming a positive instance is missing */
|
/*!
|
||||||
|
* \brief the accumulated precision,
|
||||||
|
* assuming a positive instance is missing
|
||||||
|
*/
|
||||||
float ap_acc_miss;
|
float ap_acc_miss;
|
||||||
/* \brief the accumulated precision assuming that one more positive instance is inserted ahead*/
|
/*!
|
||||||
|
* \brief the accumulated precision,
|
||||||
|
* assuming that one more positive instance is inserted ahead
|
||||||
|
*/
|
||||||
float ap_acc_add;
|
float ap_acc_add;
|
||||||
/* \brief the accumulated positive instance count */
|
/* \brief the accumulated positive instance count */
|
||||||
float hits;
|
float hits;
|
||||||
@ -454,7 +460,7 @@ class LambdaRankObjMAP : public LambdaRankObj {
|
|||||||
MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits)
|
MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits)
|
||||||
: ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {}
|
: ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {}
|
||||||
};
|
};
|
||||||
/*
|
/*!
|
||||||
* \brief Obtain the delta MAP if trying to switch the positions of instances in index1 or index2
|
* \brief Obtain the delta MAP if trying to switch the positions of instances in index1 or index2
|
||||||
* in sorted triples
|
* in sorted triples
|
||||||
* \param sorted_list the list containing entry information
|
* \param sorted_list the list containing entry information
|
||||||
@ -463,7 +469,8 @@ class LambdaRankObjMAP : public LambdaRankObj {
|
|||||||
*/
|
*/
|
||||||
inline float GetLambdaMAP(const std::vector<ListEntry> &sorted_list,
|
inline float GetLambdaMAP(const std::vector<ListEntry> &sorted_list,
|
||||||
int index1, int index2,
|
int index1, int index2,
|
||||||
std::vector<MAPStats> &map_stats){
|
std::vector<MAPStats> *p_map_stats) {
|
||||||
|
std::vector<MAPStats> &map_stats = *p_map_stats;
|
||||||
if (index1 == index2 || map_stats[map_stats.size() - 1].hits == 0) {
|
if (index1 == index2 || map_stats[map_stats.size() - 1].hits == 0) {
|
||||||
return 0.0f;
|
return 0.0f;
|
||||||
}
|
}
|
||||||
@ -482,18 +489,18 @@ class LambdaRankObjMAP : public LambdaRankObj {
|
|||||||
changed += map_stats[index2 - 1].ap_acc_miss - map_stats[index1].ap_acc_miss;
|
changed += map_stats[index2 - 1].ap_acc_miss - map_stats[index1].ap_acc_miss;
|
||||||
changed += map_stats[index2].hits / (index2 + 1);
|
changed += map_stats[index2].hits / (index2 + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
float ans = (changed - original) / (map_stats[map_stats.size() - 1].hits);
|
float ans = (changed - original) / (map_stats[map_stats.size() - 1].hits);
|
||||||
if (ans < 0) ans = -ans;
|
if (ans < 0) ans = -ans;
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
* \brief obtain preprocessing results for calculating delta MAP
|
* \brief obtain preprocessing results for calculating delta MAP
|
||||||
* \param sorted_list the list containing entry information
|
* \param sorted_list the list containing entry information
|
||||||
* \param map_stats a vector containing the accumulated precisions for each position in a list
|
* \param map_stats a vector containing the accumulated precisions for each position in a list
|
||||||
*/
|
*/
|
||||||
inline void GetMAPStats(const std::vector<ListEntry> &sorted_list,
|
inline void GetMAPStats(const std::vector<ListEntry> &sorted_list,
|
||||||
std::vector<MAPStats> &map_acc){
|
std::vector<MAPStats> *p_map_acc) {
|
||||||
|
std::vector<MAPStats> &map_acc = *p_map_acc;
|
||||||
map_acc.resize(sorted_list.size());
|
map_acc.resize(sorted_list.size());
|
||||||
float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
|
float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
|
||||||
for (size_t i = 1; i <= sorted_list.size(); ++i) {
|
for (size_t i = 1; i <= sorted_list.size(); ++i) {
|
||||||
@ -503,16 +510,18 @@ class LambdaRankObjMAP : public LambdaRankObj {
|
|||||||
acc2 += (hit - 1) / i;
|
acc2 += (hit - 1) / i;
|
||||||
acc3 += (hit + 1) / i;
|
acc3 += (hit + 1) / i;
|
||||||
}
|
}
|
||||||
map_acc[i - 1] = MAPStats(acc1,acc2,acc3,hit);
|
map_acc[i - 1] = MAPStats(acc1, acc2, acc3, hit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void GetLambdaWeight(const std::vector<ListEntry> &sorted_list, std::vector<LambdaPair> *io_pairs) {
|
virtual void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
||||||
|
std::vector<LambdaPair> *io_pairs) {
|
||||||
std::vector<LambdaPair> &pairs = *io_pairs;
|
std::vector<LambdaPair> &pairs = *io_pairs;
|
||||||
std::vector<MAPStats> map_stats;
|
std::vector<MAPStats> map_stats;
|
||||||
GetMAPStats(sorted_list, map_stats);
|
GetMAPStats(sorted_list, &map_stats);
|
||||||
for (size_t i = 0; i < pairs.size(); ++i) {
|
for (size_t i = 0; i < pairs.size(); ++i) {
|
||||||
pairs[i].weight =
|
pairs[i].weight =
|
||||||
GetLambdaMAP(sorted_list, pairs[i].pos_index, pairs[i].neg_index, map_stats);
|
GetLambdaMAP(sorted_list, pairs[i].pos_index,
|
||||||
|
pairs[i].neg_index, &map_stats);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user