Convert labels into tensor. (#7456)

* Add a new ctor to tensor for `initilizer_list`.
* Change labels from host device vector to tensor.
* Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
Jiaming Yuan
2021-12-17 00:58:35 +08:00
committed by GitHub
parent 6f8a4633b7
commit 5b1161bb64
35 changed files with 319 additions and 258 deletions

View File

@@ -127,14 +127,16 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
total_batch_size += batch.Size();
// Append meta information if available
if (batch.Labels() != nullptr) {
auto& labels = info_.labels_.HostVector();
labels.insert(labels.end(), batch.Labels(),
batch.Labels() + batch.Size());
info_.labels.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
shape[1] = 1;
auto& labels = data->HostVector();
labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size());
shape[0] += batch.Size();
});
}
if (batch.Weights() != nullptr) {
auto& weights = info_.weights_.HostVector();
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
}
if (batch.BaseMargin() != nullptr) {
info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),