Allow import via python datatable. (#3272)
* Allow import via python datatable. * Write unit tests * Refactor dt API functions * Refactor python code * Lint fixes * Address review comments
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
eecf341ea7
commit
9ac163d0bb
@@ -18,6 +18,7 @@
|
||||
#include "../common/io.h"
|
||||
#include "../common/group_data.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
// booster wrapper for backward compatible reason.
|
||||
class Booster {
|
||||
@@ -439,6 +440,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
|
||||
const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
|
||||
// const int nthreadmax = omp_get_max_threads();
|
||||
if (nthread <= 0) nthread=nthreadmax;
|
||||
int nthread_orig = omp_get_max_threads();
|
||||
omp_set_num_threads(nthread);
|
||||
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
@@ -497,12 +499,150 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
|
||||
}
|
||||
}
|
||||
}
|
||||
// restore omp state
|
||||
omp_set_num_threads(nthread_orig);
|
||||
|
||||
mat.info.num_nonzero_ = mat.page_.data.size();
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Column element types accepted by the datatable import path.
// Enumerators are numbered consecutively from 0 (kFloat32) to 7 (kUnknown).
enum class DTType : uint8_t {
  kFloat32 = 0,  // "float32"
  kFloat64,      // "float64"
  kBool8,        // "bool8"
  kInt32,        // "int32"
  kInt8,         // "int8"
  kInt16,        // "int16"
  kInt64,        // "int64"
  kUnknown       // fallback for an unrecognised type name
};
|
||||
|
||||
// Translate a datatable type name (e.g. "float32", "int8") into a DTType.
// An unrecognised name is reported through LOG(FATAL); the trailing return of
// kUnknown is only reached if that macro returns control.
DTType DTGetType(std::string type_string) {
  struct NameToType {
    const char* name;
    DTType type;
  };
  // Names are matched exactly (case-sensitive), in this order.
  static constexpr NameToType kKnownTypes[] = {
      {"float32", DTType::kFloat32},
      {"float64", DTType::kFloat64},
      {"bool8", DTType::kBool8},
      {"int32", DTType::kInt32},
      {"int8", DTType::kInt8},
      {"int16", DTType::kInt16},
      {"int64", DTType::kInt64},
  };

  for (const NameToType& candidate : kKnownTypes) {
    if (type_string == candidate.name) {
      return candidate.type;
    }
  }

  LOG(FATAL) << "Unknown data table type.";
  return DTType::kUnknown;
}
|
||||
|
||||
float DTGetValue(void* column, DTType dt_type, size_t ridx) {
|
||||
float missing = std::numeric_limits<float>::quiet_NaN();
|
||||
switch (dt_type) {
|
||||
case DTType::kFloat32: {
|
||||
float val = reinterpret_cast<float*>(column)[ridx];
|
||||
return std::isfinite(val) ? val : missing;
|
||||
}
|
||||
case DTType::kFloat64: {
|
||||
double val = reinterpret_cast<double*>(column)[ridx];
|
||||
return std::isfinite(val) ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kBool8: {
|
||||
bool val = reinterpret_cast<bool*>(column)[ridx];
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
case DTType::kInt32: {
|
||||
int32_t val = reinterpret_cast<int32_t*>(column)[ridx];
|
||||
return val != (-2147483647 - 1) ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kInt8: {
|
||||
int8_t val = reinterpret_cast<int8_t*>(column)[ridx];
|
||||
return val != -128 ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kInt16: {
|
||||
int16_t val = reinterpret_cast<int16_t*>(column)[ridx];
|
||||
return val != -32768 ? static_cast<float>(val) : missing;
|
||||
}
|
||||
case DTType::kInt64: {
|
||||
int64_t val = reinterpret_cast<int64_t*>(column)[ridx];
|
||||
return val != -9223372036854775807 - 1 ? static_cast<float>(val)
|
||||
: missing;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Unknown data table type.";
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build a DMatrix from column-major buffers (one buffer per column).
//
//   data            array of `ncol` column buffers; data[j] is read through
//                   DTGetValue as `nrow` elements of the type named by
//                   feature_stypes[j]
//   feature_stypes  array of `ncol` type names understood by DTGetType
//   nrow, ncol      matrix dimensions
//   out             receives a newly allocated std::shared_ptr<DMatrix>
//   nthread         OpenMP thread count; values <= 0 select
//                   max(omp_get_num_procs() / 2 - 1, 1). Forced to 1 when
//                   nrow * ncol <= 10000 * 50.
//
// Values for which DTGetValue returns NaN are treated as missing and are not
// stored. The return value is whatever API_BEGIN()/API_END() produce.
XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
                                  xgboost::bst_ulong nrow,
                                  xgboost::bst_ulong ncol, DMatrixHandle* out,
                                  int nthread) {
  // avoid openmp unless enough data to be worth it to avoid overhead costs
  if (nrow * ncol <= 10000 * 50) {
    nthread = 1;
  }

  API_BEGIN();

  // Resolve every column type once, serially, before any OpenMP state is
  // changed. DTGetType reports unknown names via LOG(FATAL); doing this here
  // keeps that failure out of the parallel regions below (where it could not
  // be handled by the enclosing API_BEGIN/API_END - presumably a try/catch,
  // confirm) and avoids re-parsing the type string in every thread and pass.
  std::vector<DTType> dtypes(ncol);
  for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
    dtypes[j] = DTGetType(feature_stypes[j]);
  }

  const int nthreadmax = std::max(omp_get_num_procs() / 2 - 1, 1);
  if (nthread <= 0) nthread = nthreadmax;
  int nthread_orig = omp_get_max_threads();
  omp_set_num_threads(nthread);

  std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
  data::SimpleCSRSource& mat = *source;
  mat.page_.offset.resize(1 + nrow);
  mat.info.num_row_ = nrow;
  mat.info.num_col_ = ncol;

  // Pass 1: count the non-missing entries of each row into offset[row + 1].
  // Columns are visited one after another by the whole team; within a column
  // the rows are split across threads, so each offset slot is updated by a
  // single thread at a time.
#pragma omp parallel num_threads(nthread)
  {
    for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
      const DTType dtype = dtypes[j];
#pragma omp for schedule(static)
      for (omp_ulong i = 0; i < nrow; ++i) {
        float val = DTGetValue(data[j], dtype, i);
        if (!std::isnan(val)) {
          mat.page_.offset[i + 1]++;
        }
      }
    }
  }

  // do cumulative sum (to avoid otherwise need to copy)
  PrefixSum(&mat.page_.offset[0], mat.page_.offset.size());

  mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back());

  // Pass 2: fill the data array (now that the size is known, no need for slow
  // push_back()). position[i] is the number of entries already written for
  // row i, i.e. the next free slot relative to offset[i].
  std::vector<size_t> position(nrow);
#pragma omp parallel num_threads(nthread)
  {
    for (xgboost::bst_ulong j = 0; j < ncol; ++j) {
      const DTType dtype = dtypes[j];
#pragma omp for schedule(static)
      for (omp_ulong i = 0; i < nrow; ++i) {
        float val = DTGetValue(data[j], dtype, i);
        if (!std::isnan(val)) {
          mat.page_.data[mat.page_.offset[i] + position[i]] = Entry(j, val);
          position[i]++;
        }
      }
    }
  }

  // restore omp state
  omp_set_num_threads(nthread_orig);

  mat.info.num_nonzero_ = mat.page_.data.size();
  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
  API_END();
}
|
||||
|
||||
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
|
||||
Reference in New Issue
Block a user