Cudf support. (#4745)
* Initial support for cudf integration. * Add two C APIs for consuming data and metainfo. * Add CopyFrom for SimpleCSRSource as a generic function to consume the data. * Add FromDeviceColumnar for consuming device data. * Add new MetaInfo::SetInfo for consuming label, weight etc.
This commit is contained in:
committed by
Rory Mitchell
parent
ab357dd41c
commit
9700776597
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2014 by Contributors
|
||||
// Copyright (c) 2014-2019 by Contributors
|
||||
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/learner.h>
|
||||
@@ -16,7 +16,7 @@
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "./c_api_error.h"
|
||||
#include "c_api_error.h"
|
||||
#include "../data/simple_csr_source.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/io.h"
|
||||
@@ -189,6 +189,16 @@ int XGDMatrixCreateFromDataIter(
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(
|
||||
char const* c_json_strs, DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
std::string json_str {c_json_strs};
|
||||
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
|
||||
source->CopyFrom(json_str);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
|
||||
const unsigned* indices,
|
||||
const bst_float* data,
|
||||
@@ -679,9 +689,9 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const bst_float* info,
|
||||
xgboost::bst_ulong len) {
|
||||
const char* field,
|
||||
const bst_float* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
@@ -689,10 +699,20 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
|
||||
char const* field,
|
||||
char const* interface_c_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, interface_c_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const unsigned* info,
|
||||
xgboost::bst_ulong len) {
|
||||
const char* field,
|
||||
const unsigned* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
@@ -771,7 +791,7 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
*out = static_cast<size_t>(
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user