Take datatable as row major input. (#8472)
* Take datatable as row major input. Try to avoid a transform with dense table.
This commit is contained in:
@@ -473,16 +473,7 @@ class CSCAdapter : public detail::SingleBatchDataIter<CSCAdapterBatch> {
|
||||
};
|
||||
|
||||
class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
DataTableAdapterBatch(void** data, const char** feature_stypes,
|
||||
size_t num_rows, size_t num_features)
|
||||
: data_(data),
|
||||
feature_stypes_(feature_stypes),
|
||||
num_features_(num_features),
|
||||
num_rows_(num_rows) {}
|
||||
|
||||
private:
|
||||
enum class DTType : uint8_t {
|
||||
enum class DTType : std::uint8_t {
|
||||
kFloat32 = 0,
|
||||
kFloat64 = 1,
|
||||
kBool8 = 2,
|
||||
@@ -493,7 +484,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
kUnknown = 7
|
||||
};
|
||||
|
||||
DTType DTGetType(std::string type_string) const {
|
||||
static DTType DTGetType(std::string type_string) {
|
||||
if (type_string == "float32") {
|
||||
return DTType::kFloat32;
|
||||
} else if (type_string == "float64") {
|
||||
@@ -514,8 +505,23 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
DataTableAdapterBatch(void const* const* const data, char const* const* feature_stypes,
|
||||
std::size_t num_rows, std::size_t num_features)
|
||||
: data_(data), num_rows_(num_rows) {
|
||||
CHECK(feature_types_.empty());
|
||||
std::transform(feature_stypes, feature_stypes + num_features,
|
||||
std::back_inserter(feature_types_),
|
||||
[](char const* stype) { return DTGetType(stype); });
|
||||
}
|
||||
|
||||
private:
|
||||
class Line {
|
||||
float DTGetValue(const void* column, DTType dt_type, size_t ridx) const {
|
||||
std::size_t row_idx_;
|
||||
void const* const* const data_;
|
||||
std::vector<DTType> const& feature_types_;
|
||||
|
||||
float DTGetValue(void const* column, DTType dt_type, std::size_t ridx) const {
|
||||
float missing = std::numeric_limits<float>::quiet_NaN();
|
||||
switch (dt_type) {
|
||||
case DTType::kFloat32: {
|
||||
@@ -544,8 +550,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
case DTType::kInt64: {
|
||||
int64_t val = reinterpret_cast<const int64_t*>(column)[ridx];
|
||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
|
||||
: missing;
|
||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val) : missing;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Unknown data table type.";
|
||||
@@ -555,51 +560,41 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
|
||||
public:
|
||||
Line(DTType type, size_t size, size_t column_idx, const void* column)
|
||||
: type_(type), size_(size), column_idx_(column_idx), column_(column) {}
|
||||
|
||||
size_t Size() const { return size_; }
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return COOTuple{idx, column_idx_, DTGetValue(column_, type_, idx)};
|
||||
Line(std::size_t ridx, void const* const* const data, std::vector<DTType> const& ft)
|
||||
: row_idx_{ridx}, data_{data}, feature_types_{ft} {}
|
||||
std::size_t Size() const { return feature_types_.size(); }
|
||||
COOTuple GetElement(std::size_t idx) const {
|
||||
return COOTuple{row_idx_, idx, DTGetValue(data_[idx], feature_types_[idx], row_idx_)};
|
||||
}
|
||||
|
||||
private:
|
||||
DTType type_;
|
||||
size_t size_;
|
||||
size_t column_idx_;
|
||||
const void* column_;
|
||||
};
|
||||
|
||||
public:
|
||||
size_t Size() const { return num_features_; }
|
||||
const Line GetLine(size_t idx) const {
|
||||
return Line(DTGetType(feature_stypes_[idx]), num_rows_, idx, data_[idx]);
|
||||
}
|
||||
static constexpr bool kIsRowMajor = false;
|
||||
size_t Size() const { return num_rows_; }
|
||||
const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
|
||||
private:
|
||||
void** data_;
|
||||
const char** feature_stypes_;
|
||||
size_t num_features_;
|
||||
size_t num_rows_;
|
||||
void const* const* const data_;
|
||||
|
||||
std::vector<DTType> feature_types_;
|
||||
std::size_t num_rows_;
|
||||
};
|
||||
|
||||
class DataTableAdapter
|
||||
: public detail::SingleBatchDataIter<DataTableAdapterBatch> {
|
||||
class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatch> {
|
||||
public:
|
||||
DataTableAdapter(void** data, const char** feature_stypes, size_t num_rows,
|
||||
size_t num_features)
|
||||
DataTableAdapter(void** data, const char** feature_stypes, std::size_t num_rows,
|
||||
std::size_t num_features)
|
||||
: batch_(data, feature_stypes, num_rows, num_features),
|
||||
num_rows_(num_rows),
|
||||
num_columns_(num_features) {}
|
||||
const DataTableAdapterBatch& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return num_rows_; }
|
||||
size_t NumColumns() const { return num_columns_; }
|
||||
std::size_t NumRows() const { return num_rows_; }
|
||||
std::size_t NumColumns() const { return num_columns_; }
|
||||
|
||||
private:
|
||||
DataTableAdapterBatch batch_;
|
||||
size_t num_rows_;
|
||||
size_t num_columns_;
|
||||
std::size_t num_rows_;
|
||||
std::size_t num_columns_;
|
||||
};
|
||||
|
||||
class FileAdapterBatch {
|
||||
|
||||
Reference in New Issue
Block a user