Convert labels into tensor. (#7456)
* Add a new ctor to tensor for `initilizer_list`. * Change labels from host device vector to tensor. * Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
@@ -21,10 +21,10 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
|
||||
|
||||
// Invalid dataset
|
||||
MetaInfo info;
|
||||
info.labels_ = {0, 0};
|
||||
info.labels = linalg::Tensor<float, 2>{{0.0f, 0.0f}, {2}, -1};
|
||||
float auc = metric->Eval({1, 1}, info, false);
|
||||
ASSERT_TRUE(std::isnan(auc));
|
||||
info.labels_ = HostDeviceVector<float>{};
|
||||
*info.labels.Data() = HostDeviceVector<float>{};
|
||||
auc = metric->Eval(HostDeviceVector<float>{}, info, false);
|
||||
ASSERT_TRUE(std::isnan(auc));
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
MetaInfo info;
|
||||
auto &h_labels = info.labels_.HostVector();
|
||||
auto &h_labels = info.labels.Data()->HostVector();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
|
||||
@@ -11,13 +11,14 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
MetaInfo info;
|
||||
auto &h_labels = info.labels_.HostVector();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
|
||||
size_t n_samples = 2048, n_classes = 4;
|
||||
h_labels.resize(n_samples);
|
||||
|
||||
info.labels.Reshape(n_samples);
|
||||
auto &h_labels = info.labels.Data()->HostVector();
|
||||
h_predts.resize(n_samples * n_classes);
|
||||
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user