Take datatable as row major input. (#8472)

* Take datatable as row major input.

Try to avoid a transform with dense table.
This commit is contained in:
Jiaming Yuan
2022-11-24 09:20:13 +08:00
committed by GitHub
parent 284dcf8d22
commit e07245f110
4 changed files with 84 additions and 92 deletions

View File

@@ -473,16 +473,7 @@ class CSCAdapter : public detail::SingleBatchDataIter<CSCAdapterBatch> {
};
class DataTableAdapterBatch : public detail::NoMetaInfo {
public:
DataTableAdapterBatch(void** data, const char** feature_stypes,
size_t num_rows, size_t num_features)
: data_(data),
feature_stypes_(feature_stypes),
num_features_(num_features),
num_rows_(num_rows) {}
private:
enum class DTType : uint8_t {
enum class DTType : std::uint8_t {
kFloat32 = 0,
kFloat64 = 1,
kBool8 = 2,
@@ -493,7 +484,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
kUnknown = 7
};
DTType DTGetType(std::string type_string) const {
static DTType DTGetType(std::string type_string) {
if (type_string == "float32") {
return DTType::kFloat32;
} else if (type_string == "float64") {
@@ -514,8 +505,23 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
}
}
public:
DataTableAdapterBatch(void const* const* const data, char const* const* feature_stypes,
std::size_t num_rows, std::size_t num_features)
: data_(data), num_rows_(num_rows) {
CHECK(feature_types_.empty());
std::transform(feature_stypes, feature_stypes + num_features,
std::back_inserter(feature_types_),
[](char const* stype) { return DTGetType(stype); });
}
private:
class Line {
float DTGetValue(const void* column, DTType dt_type, size_t ridx) const {
std::size_t row_idx_;
void const* const* const data_;
std::vector<DTType> const& feature_types_;
float DTGetValue(void const* column, DTType dt_type, std::size_t ridx) const {
float missing = std::numeric_limits<float>::quiet_NaN();
switch (dt_type) {
case DTType::kFloat32: {
@@ -544,8 +550,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
}
case DTType::kInt64: {
int64_t val = reinterpret_cast<const int64_t*>(column)[ridx];
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
: missing;
return val != -9223372036854775807 - 1 ? static_cast<float>(val) : missing;
}
default: {
LOG(FATAL) << "Unknown data table type.";
@@ -555,51 +560,41 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
}
public:
Line(DTType type, size_t size, size_t column_idx, const void* column)
: type_(type), size_(size), column_idx_(column_idx), column_(column) {}
size_t Size() const { return size_; }
COOTuple GetElement(size_t idx) const {
return COOTuple{idx, column_idx_, DTGetValue(column_, type_, idx)};
Line(std::size_t ridx, void const* const* const data, std::vector<DTType> const& ft)
: row_idx_{ridx}, data_{data}, feature_types_{ft} {}
std::size_t Size() const { return feature_types_.size(); }
COOTuple GetElement(std::size_t idx) const {
return COOTuple{row_idx_, idx, DTGetValue(data_[idx], feature_types_[idx], row_idx_)};
}
private:
DTType type_;
size_t size_;
size_t column_idx_;
const void* column_;
};
public:
size_t Size() const { return num_features_; }
const Line GetLine(size_t idx) const {
return Line(DTGetType(feature_stypes_[idx]), num_rows_, idx, data_[idx]);
}
static constexpr bool kIsRowMajor = false;
size_t Size() const { return num_rows_; }
const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
static constexpr bool kIsRowMajor = true;
private:
void** data_;
const char** feature_stypes_;
size_t num_features_;
size_t num_rows_;
void const* const* const data_;
std::vector<DTType> feature_types_;
std::size_t num_rows_;
};
class DataTableAdapter
: public detail::SingleBatchDataIter<DataTableAdapterBatch> {
class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatch> {
public:
DataTableAdapter(void** data, const char** feature_stypes, size_t num_rows,
size_t num_features)
DataTableAdapter(void** data, const char** feature_stypes, std::size_t num_rows,
std::size_t num_features)
: batch_(data, feature_stypes, num_rows, num_features),
num_rows_(num_rows),
num_columns_(num_features) {}
const DataTableAdapterBatch& Value() const override { return batch_; }
size_t NumRows() const { return num_rows_; }
size_t NumColumns() const { return num_columns_; }
std::size_t NumRows() const { return num_rows_; }
std::size_t NumColumns() const { return num_columns_; }
private:
DataTableAdapterBatch batch_;
size_t num_rows_;
size_t num_columns_;
std::size_t num_rows_;
std::size_t num_columns_;
};
class FileAdapterBatch {