Device dmatrix (#5420)

This commit is contained in:
Rory Mitchell
2020-03-28 14:42:21 +13:00
committed by GitHub
parent 780de49ddb
commit 13b10a6370
24 changed files with 915 additions and 310 deletions

View File

@@ -4,6 +4,7 @@
#include "xgboost/learner.h"
#include "c_api_error.h"
#include "../data/device_adapter.cuh"
#include "../data/device_dmatrix.h"
using namespace xgboost; // NOLINT
@@ -29,3 +30,25 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
API_END();
}
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
bst_float missing, int nthread, int max_bin,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CudfAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
API_END();
}
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs,
bst_float missing, int nthread, int max_bin,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CupyAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
API_END();
}