Convert labels into tensor. (#7456)
* Add a new ctor to tensor for `initilizer_list`. * Change labels from host device vector to tensor. * Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
@@ -127,14 +127,16 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
total_batch_size += batch.Size();
|
||||
// Append meta information if available
|
||||
if (batch.Labels() != nullptr) {
|
||||
auto& labels = info_.labels_.HostVector();
|
||||
labels.insert(labels.end(), batch.Labels(),
|
||||
batch.Labels() + batch.Size());
|
||||
info_.labels.ModifyInplace([&](auto* data, common::Span<size_t, 2> shape) {
|
||||
shape[1] = 1;
|
||||
auto& labels = data->HostVector();
|
||||
labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size());
|
||||
shape[0] += batch.Size();
|
||||
});
|
||||
}
|
||||
if (batch.Weights() != nullptr) {
|
||||
auto& weights = info_.weights_.HostVector();
|
||||
weights.insert(weights.end(), batch.Weights(),
|
||||
batch.Weights() + batch.Size());
|
||||
weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
|
||||
}
|
||||
if (batch.BaseMargin() != nullptr) {
|
||||
info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),
|
||||
|
||||
Reference in New Issue
Block a user