@@ -158,18 +158,22 @@ struct EvalMClassBase : public Metric {
|
||||
bst_float Eval(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
bool distributed) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK(preds.Size() % info.labels_.Size() == 0)
|
||||
<< "label and prediction size not match";
|
||||
const size_t nclass = preds.Size() / info.labels_.Size();
|
||||
CHECK_GE(nclass, 1U)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
|
||||
int device = tparam_->gpu_id;
|
||||
auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
|
||||
double dat[2] { result.Residue(), result.Weights() };
|
||||
|
||||
if (info.labels_.Size() == 0) {
|
||||
CHECK_EQ(preds.Size(), 0);
|
||||
} else {
|
||||
CHECK(preds.Size() % info.labels_.Size() == 0) << "label and prediction size not match";
|
||||
}
|
||||
double dat[2] { 0.0, 0.0 };
|
||||
if (info.labels_.Size() != 0) {
|
||||
const size_t nclass = preds.Size() / info.labels_.Size();
|
||||
CHECK_GE(nclass, 1U)
|
||||
<< "mlogloss and merror are only used for multi-class classification,"
|
||||
<< " use logloss for binary classification";
|
||||
int device = tparam_->gpu_id;
|
||||
auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
|
||||
dat[0] = result.Residue();
|
||||
dat[1] = result.Weights();
|
||||
}
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user