Fix categorical data with external memory. (#10433)

This commit is contained in:
Jiaming Yuan
2024-06-18 04:34:54 +08:00
committed by GitHub
parent a8ddbac163
commit b4cc350ec5
5 changed files with 31 additions and 7 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2017-2024 by XGBoost Contributors
* Copyright 2017-2024, XGBoost Contributors
* \file hist_util.h
* \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen
@@ -11,7 +11,6 @@
#include <cstdint> // for uint32_t
#include <limits>
#include <map>
#include <memory>
#include <utility>
#include <vector>

View File

@@ -4,7 +4,6 @@
*/
#include "gradient_index.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // for forward
@@ -126,8 +125,8 @@ INSTANTIATION_PUSH(data::ColumnarAdapterBatch)
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
// Must resize instead of allocating a new one. This function is called every time a
// new batch is pushed, and we grow the size accordingly without loosing the data the
// previous batches.
// new batch is pushed, and we grow the size accordingly without losing the data in
// the previous batches.
using T = decltype(t);
std::size_t n_bytes = sizeof(T) * n_index;
CHECK_GE(n_bytes, this->data.size());

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
@@ -23,6 +23,15 @@ inline bool ReadHistogramCuts(common::HistogramCuts *cuts, common::AlignedResour
if (!common::ReadVec(fi, &cuts->min_vals_.HostVector())) {
return false;
}
bool has_cat{false};
if (!fi->Read(&has_cat)) {
return false;
}
decltype(cuts->MaxCategory()) max_cat{0};
if (!fi->Read(&max_cat)) {
return false;
}
cuts->SetCategorical(has_cat, max_cat);
return true;
}
@@ -32,6 +41,8 @@ inline std::size_t WriteHistogramCuts(common::HistogramCuts const &cuts,
bytes += common::WriteVec(fo, cuts.Values());
bytes += common::WriteVec(fo, cuts.Ptrs());
bytes += common::WriteVec(fo, cuts.MinValues());
bytes += fo->Write(cuts.HasCategorical());
bytes += fo->Write(cuts.MaxCategory());
return bytes;
}
} // namespace xgboost::data