Support dataframe data format in native XGBoost. (#9828)
- Implement a columnar adapter. - Refactor Python pandas handling code to avoid converting into a single numpy array. - Add support in R for transforming columns. - Support R data.frame and factor type.
This commit is contained in:
parent
b3700bbb3f
commit
faf0f2df10
@ -19,7 +19,8 @@
|
||||
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
|
||||
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
||||
#' @param silent whether to suppress printing an informational message after loading from a file.
|
||||
#' @param feature_names Set names for features.
|
||||
#' @param feature_names Set names for features. Overrides column names in data
|
||||
#' frame and matrix.
|
||||
#' @param nthread Number of threads used for creating DMatrix.
|
||||
#' @param group Group size for all ranking group.
|
||||
#' @param qid Query ID for data samples, used for ranking.
|
||||
@ -32,6 +33,8 @@
|
||||
#' If a DMatrix gets serialized and then de-serialized (for example, when saving data in an R session or caching
|
||||
#' chunks in an Rmd file), the resulting object will not be usable anymore and will need to be reconstructed
|
||||
#' from the original source of data.
|
||||
#' @param enable_categorical Experimental support of specializing for
|
||||
#' categorical features. JSON/UBJSON serialization format is required.
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
@ -58,19 +61,26 @@ xgb.DMatrix <- function(
|
||||
qid = NULL,
|
||||
label_lower_bound = NULL,
|
||||
label_upper_bound = NULL,
|
||||
feature_weights = NULL
|
||||
feature_weights = NULL,
|
||||
enable_categorical = FALSE
|
||||
) {
|
||||
if (!is.null(group) && !is.null(qid)) {
|
||||
stop("Either one of 'group' or 'qid' should be NULL")
|
||||
}
|
||||
ctypes <- NULL
|
||||
if (typeof(data) == "character") {
|
||||
if (length(data) > 1)
|
||||
stop("'data' has class 'character' and length ", length(data),
|
||||
".\n 'data' accepts either a numeric matrix or a single filename.")
|
||||
if (length(data) > 1) {
|
||||
stop(
|
||||
"'data' has class 'character' and length ", length(data),
|
||||
".\n 'data' accepts either a numeric matrix or a single filename."
|
||||
)
|
||||
}
|
||||
data <- path.expand(data)
|
||||
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
|
||||
} else if (is.matrix(data)) {
|
||||
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
|
||||
handle <- .Call(
|
||||
XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1))
|
||||
)
|
||||
} else if (inherits(data, "dgCMatrix")) {
|
||||
handle <- .Call(
|
||||
XGDMatrixCreateFromCSC_R,
|
||||
@ -103,6 +113,39 @@ xgb.DMatrix <- function(
|
||||
missing,
|
||||
as.integer(NVL(nthread, -1))
|
||||
)
|
||||
} else if (is.data.frame(data)) {
|
||||
ctypes <- sapply(data, function(x) {
|
||||
if (is.factor(x)) {
|
||||
if (!enable_categorical) {
|
||||
stop(
|
||||
"When factor type is used, the parameter `enable_categorical`",
|
||||
" must be set to TRUE."
|
||||
)
|
||||
}
|
||||
"c"
|
||||
} else if (is.integer(x)) {
|
||||
"int"
|
||||
} else if (is.logical(x)) {
|
||||
"i"
|
||||
} else {
|
||||
if (!is.numeric(x)) {
|
||||
stop("Invalid type in dataframe.")
|
||||
}
|
||||
"float"
|
||||
}
|
||||
})
|
||||
## as.data.frame somehow converts integer/logical into real.
|
||||
data <- as.data.frame(sapply(data, function(x) {
|
||||
if (is.factor(x)) {
|
||||
## XGBoost uses 0-based indexing.
|
||||
as.numeric(x) - 1
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}))
|
||||
handle <- .Call(
|
||||
XGDMatrixCreateFromDF_R, data, missing, as.integer(NVL(nthread, -1))
|
||||
)
|
||||
} else {
|
||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||
}
|
||||
@ -137,6 +180,9 @@ xgb.DMatrix <- function(
|
||||
if (!is.null(feature_weights)) {
|
||||
setinfo(dmat, "feature_weights", feature_weights)
|
||||
}
|
||||
if (!is.null(ctypes)) {
|
||||
setinfo(dmat, "feature_type", ctypes)
|
||||
}
|
||||
|
||||
return(dmat)
|
||||
}
|
||||
|
||||
@ -17,7 +17,8 @@ xgb.DMatrix(
|
||||
qid = NULL,
|
||||
label_lower_bound = NULL,
|
||||
label_upper_bound = NULL,
|
||||
feature_weights = NULL
|
||||
feature_weights = NULL,
|
||||
enable_categorical = FALSE
|
||||
)
|
||||
}
|
||||
\arguments{
|
||||
@ -42,7 +43,8 @@ It is useful when a 0 or some other extreme value represents missing values in d
|
||||
|
||||
\item{silent}{whether to suppress printing an informational message after loading from a file.}
|
||||
|
||||
\item{feature_names}{Set names for features.}
|
||||
\item{feature_names}{Set names for features. Overrides column names in data
|
||||
frame and matrix.}
|
||||
|
||||
\item{nthread}{Number of threads used for creating DMatrix.}
|
||||
|
||||
@ -55,6 +57,9 @@ It is useful when a 0 or some other extreme value represents missing values in d
|
||||
\item{label_upper_bound}{Upper bound for survival training.}
|
||||
|
||||
\item{feature_weights}{Set feature weights for column sampling.}
|
||||
|
||||
\item{enable_categorical}{Experimental support of specializing for
|
||||
categorical features. JSON/UBJSON serialization format is required.}
|
||||
}
|
||||
\description{
|
||||
Construct xgb.DMatrix object from either a dense matrix, a sparse matrix, or a local file.
|
||||
|
||||
@ -41,6 +41,7 @@ extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetFloatInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetUIntInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromDF_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||
@ -79,6 +80,7 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||
{"XGDMatrixGetFloatInfo_R", (DL_FUNC) &XGDMatrixGetFloatInfo_R, 2},
|
||||
{"XGDMatrixGetUIntInfo_R", (DL_FUNC) &XGDMatrixGetUIntInfo_R, 2},
|
||||
{"XGDMatrixCreateFromDF_R", (DL_FUNC) &XGDMatrixCreateFromDF_R, 3},
|
||||
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||
|
||||
@ -223,6 +223,69 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixCreateFromDF_R(SEXP df, SEXP missing, SEXP n_threads) {
|
||||
SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
||||
R_API_BEGIN();
|
||||
|
||||
DMatrixHandle handle;
|
||||
|
||||
auto make_vec = [&](auto const *ptr, std::int32_t len) {
|
||||
auto v = xgboost::linalg::MakeVec(ptr, len);
|
||||
return xgboost::linalg::ArrayInterface(v);
|
||||
};
|
||||
|
||||
std::int32_t rc{0};
|
||||
{
|
||||
using xgboost::Json;
|
||||
auto n_features = Rf_xlength(df);
|
||||
std::vector<Json> array(n_features);
|
||||
CHECK_GT(n_features, 0);
|
||||
auto len = Rf_xlength(VECTOR_ELT(df, 0));
|
||||
// The `data.frame` in R actually converts all data into numeric. The other type
|
||||
// handlers here are not used. At the moment they are kept as a reference for when we
|
||||
// can avoid making data copies during transformation.
|
||||
for (decltype(n_features) i = 0; i < n_features; ++i) {
|
||||
switch (TYPEOF(VECTOR_ELT(df, i))) {
|
||||
case INTSXP: {
|
||||
auto const *ptr = INTEGER(VECTOR_ELT(df, i));
|
||||
array[i] = make_vec(ptr, len);
|
||||
break;
|
||||
}
|
||||
case REALSXP: {
|
||||
auto const *ptr = REAL(VECTOR_ELT(df, i));
|
||||
array[i] = make_vec(ptr, len);
|
||||
break;
|
||||
}
|
||||
case LGLSXP: {
|
||||
auto const *ptr = LOGICAL(VECTOR_ELT(df, i));
|
||||
array[i] = make_vec(ptr, len);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "data.frame has unsupported type.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json jinterface{std::move(array)};
|
||||
auto sinterface = Json::Dump(jinterface);
|
||||
Json jconfig{xgboost::Object{}};
|
||||
jconfig["missing"] = asReal(missing);
|
||||
jconfig["nthread"] = asInteger(n_threads);
|
||||
auto sconfig = Json::Dump(jconfig);
|
||||
|
||||
rc = XGDMatrixCreateFromColumnar(sinterface.c_str(), sconfig.c_str(), &handle);
|
||||
}
|
||||
|
||||
CHECK_CALL(rc);
|
||||
R_SetExternalPtrAddr(ret, handle);
|
||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||
R_API_END();
|
||||
Rf_unprotect(1);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
namespace {
|
||||
void CreateFromSparse(SEXP indptr, SEXP indices, SEXP data, std::string *indptr_str,
|
||||
std::string *indices_str, std::string *data_str) {
|
||||
@ -298,6 +361,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
||||
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
|
||||
config.c_str(), &handle);
|
||||
}
|
||||
CHECK_CALL(res_code);
|
||||
R_SetExternalPtrAddr(ret, handle);
|
||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||
R_API_END();
|
||||
|
||||
@ -53,6 +53,16 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
|
||||
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
|
||||
SEXP missing,
|
||||
SEXP n_threads);
|
||||
|
||||
/**
|
||||
* @brief Create matrix content from a data frame.
|
||||
* @param data R data.frame object
|
||||
* @param missing which value to represent missing value
|
||||
* @param n_threads Number of threads used to construct DMatrix from dense matrix.
|
||||
* @return created dmatrix
|
||||
*/
|
||||
XGB_DLL SEXP XGDMatrixCreateFromDF_R(SEXP df, SEXP missing, SEXP n_threads);
|
||||
|
||||
/*!
|
||||
* \brief create a matrix content from CSC format
|
||||
* \param indptr pointer to column headers
|
||||
|
||||
@ -322,3 +322,30 @@ test_that("xgb.DMatrix: can get group for both 'qid' and 'group' constructors",
|
||||
expected_gr <- c(0, 20, 40, 100)
|
||||
expect_equal(info_gr, expected_gr)
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: data.frame", {
|
||||
df <- data.frame(
|
||||
a = (1:4) / 10,
|
||||
num = c(1, NA, 3, 4),
|
||||
as.int = as.integer(c(1, 2, 3, 4)),
|
||||
lo = c(TRUE, FALSE, NA, TRUE),
|
||||
str.fac = c("a", "b", "d", "c"),
|
||||
as.fac = as.factor(c(3, 5, 8, 11)),
|
||||
stringsAsFactors = TRUE
|
||||
)
|
||||
|
||||
m <- xgb.DMatrix(df, enable_categorical = TRUE)
|
||||
expect_equal(colnames(m), colnames(df))
|
||||
expect_equal(
|
||||
getinfo(m, "feature_type"), c("float", "float", "int", "i", "c", "c")
|
||||
)
|
||||
expect_error(xgb.DMatrix(df))
|
||||
|
||||
df <- data.frame(
|
||||
missing = c("a", "b", "d", NA),
|
||||
valid = c("a", "b", "d", "c"),
|
||||
stringsAsFactors = TRUE
|
||||
)
|
||||
m <- xgb.DMatrix(df, enable_categorical = TRUE)
|
||||
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
|
||||
})
|
||||
|
||||
@ -78,6 +78,10 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, random_state=1994, test_size=0.2
|
||||
)
|
||||
# Be aware that the encoding for X_train and X_test are the same here. In practice,
|
||||
# we should try to use an encoder like (sklearn OrdinalEncoder) to obtain the
|
||||
# categorical values.
|
||||
|
||||
# Specify `enable_categorical` to True.
|
||||
clf = xgb.XGBClassifier(
|
||||
**params,
|
||||
|
||||
@ -159,6 +159,16 @@ XGB_DLL int XGDMatrixCreateFromURI(char const *config, DMatrixHandle *out);
|
||||
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indices,
|
||||
const float *data, size_t nindptr, size_t nelem,
|
||||
size_t num_col, DMatrixHandle *out);
|
||||
/**
|
||||
* @brief Create a DMatrix from columnar data. (table)
|
||||
*
|
||||
* @param data See @ref XGBoosterPredictFromColumnar for details.
|
||||
* @param config See @ref XGDMatrixCreateFromDense for details.
|
||||
* @param out The created dmatrix.
|
||||
*
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixCreateFromColumnar(char const *data, char const *config, DMatrixHandle *out);
|
||||
|
||||
/**
|
||||
* @example c-api-demo.c
|
||||
@ -514,6 +524,16 @@ XGB_DLL int
|
||||
XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
|
||||
const char *c_interface_str);
|
||||
|
||||
/**
|
||||
* @brief Set columnar (table) data on a DMatrix proxy.
|
||||
*
|
||||
* @param handle A DMatrix proxy created by @ref XGProxyDMatrixCreate
|
||||
* @param c_interface_str See @ref XGBoosterPredictFromColumnar for details.
|
||||
*
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle, char const *c_interface_str);
|
||||
|
||||
/*!
|
||||
* \brief Set data on a DMatrix proxy.
|
||||
*
|
||||
@ -1113,6 +1133,31 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
|
||||
* @example inference.c
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Inplace prediction from CPU columnar data. (Table)
|
||||
*
|
||||
* @note If the booster is configured to run on a CUDA device, XGBoost falls back to run
|
||||
* prediction with DMatrix with a performance warning.
|
||||
*
|
||||
* @param handle Booster handle.
|
||||
* @param values An JSON array of __array_interface__ for each column.
|
||||
* @param config See @ref XGBoosterPredictFromDMatrix for more info.
|
||||
* Additional fields for inplace prediction are:
|
||||
* - "missing": float
|
||||
* @param m An optional (NULL if not available) proxy DMatrix instance
|
||||
* storing meta info.
|
||||
*
|
||||
* @param out_shape See @ref XGBoosterPredictFromDMatrix for more info.
|
||||
* @param out_dim See @ref XGBoosterPredictFromDMatrix for more info.
|
||||
* @param out_result See @ref XGBoosterPredictFromDMatrix for more info.
|
||||
*
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
|
||||
char const *c_json_config, DMatrixHandle m,
|
||||
bst_ulong const **out_shape, bst_ulong *out_dim,
|
||||
const float **out_result);
|
||||
|
||||
/**
|
||||
* \brief Inplace prediction from CPU CSR matrix.
|
||||
*
|
||||
|
||||
@ -822,8 +822,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
|
||||
.. note:: This parameter is experimental
|
||||
|
||||
Experimental support of specializing for categorical features. Do not set
|
||||
to True unless you are interested in development. Also, JSON/UBJSON
|
||||
Experimental support of specializing for categorical features. JSON/UBJSON
|
||||
serialization format is required.
|
||||
|
||||
"""
|
||||
@ -1431,6 +1430,12 @@ class _ProxyDMatrix(DMatrix):
|
||||
_LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data))
|
||||
)
|
||||
|
||||
def _set_data_from_pandas(self, data: DataType) -> None:
|
||||
"""Set data from a pandas DataFrame. The input is a PandasTransformed instance."""
|
||||
_check_call(
|
||||
_LIB.XGProxyDMatrixSetDataColumnar(self.handle, data.array_interface())
|
||||
)
|
||||
|
||||
def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
|
||||
"""Set data from scipy csr"""
|
||||
from .data import _array_interface
|
||||
@ -2440,6 +2445,7 @@ class Booster:
|
||||
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
|
||||
|
||||
from .data import (
|
||||
PandasTransformed,
|
||||
_array_interface,
|
||||
_arrow_transform,
|
||||
_is_arrow,
|
||||
@ -2494,6 +2500,19 @@ class Booster:
|
||||
)
|
||||
)
|
||||
return _prediction_output(shape, dims, preds, False)
|
||||
if isinstance(data, PandasTransformed):
|
||||
_check_call(
|
||||
_LIB.XGBoosterPredictFromColumnar(
|
||||
self.handle,
|
||||
data.array_interface(),
|
||||
args,
|
||||
p_handle,
|
||||
ctypes.byref(shape),
|
||||
ctypes.byref(dims),
|
||||
ctypes.byref(preds),
|
||||
)
|
||||
)
|
||||
return _prediction_output(shape, dims, preds, False)
|
||||
if isinstance(data, scipy.sparse.csr_matrix):
|
||||
from .data import transform_scipy_sparse
|
||||
|
||||
|
||||
@ -65,13 +65,18 @@ def _is_scipy_csr(data: DataType) -> bool:
|
||||
return isinstance(data, scipy.sparse.csr_matrix)
|
||||
|
||||
|
||||
def _array_interface(data: np.ndarray) -> bytes:
|
||||
def _array_interface_dict(data: np.ndarray) -> dict:
|
||||
assert (
|
||||
data.dtype.hasobject is False
|
||||
), "Input data contains `object` dtype. Expecting numeric data."
|
||||
interface = data.__array_interface__
|
||||
if "mask" in interface:
|
||||
interface["mask"] = interface["mask"].__array_interface__
|
||||
return interface
|
||||
|
||||
|
||||
def _array_interface(data: np.ndarray) -> bytes:
|
||||
interface = _array_interface_dict(data)
|
||||
interface_str = bytes(json.dumps(interface), "utf-8")
|
||||
return interface_str
|
||||
|
||||
@ -265,24 +270,24 @@ pandas_nullable_mapper = {
|
||||
"Int16": "int",
|
||||
"Int32": "int",
|
||||
"Int64": "int",
|
||||
"UInt8": "i",
|
||||
"UInt16": "i",
|
||||
"UInt32": "i",
|
||||
"UInt64": "i",
|
||||
"UInt8": "int",
|
||||
"UInt16": "int",
|
||||
"UInt32": "int",
|
||||
"UInt64": "int",
|
||||
"Float32": "float",
|
||||
"Float64": "float",
|
||||
"boolean": "i",
|
||||
}
|
||||
|
||||
pandas_pyarrow_mapper = {
|
||||
"int8[pyarrow]": "i",
|
||||
"int16[pyarrow]": "i",
|
||||
"int32[pyarrow]": "i",
|
||||
"int64[pyarrow]": "i",
|
||||
"uint8[pyarrow]": "i",
|
||||
"uint16[pyarrow]": "i",
|
||||
"uint32[pyarrow]": "i",
|
||||
"uint64[pyarrow]": "i",
|
||||
"int8[pyarrow]": "int",
|
||||
"int16[pyarrow]": "int",
|
||||
"int32[pyarrow]": "int",
|
||||
"int64[pyarrow]": "int",
|
||||
"uint8[pyarrow]": "int",
|
||||
"uint16[pyarrow]": "int",
|
||||
"uint32[pyarrow]": "int",
|
||||
"uint64[pyarrow]": "int",
|
||||
"float[pyarrow]": "float",
|
||||
"float32[pyarrow]": "float",
|
||||
"double[pyarrow]": "float",
|
||||
@ -295,7 +300,7 @@ _pandas_dtype_mapper.update(pandas_pyarrow_mapper)
|
||||
|
||||
|
||||
_ENABLE_CAT_ERR = (
|
||||
"When categorical type is supplied, The experimental DMatrix parameter"
|
||||
"When categorical type is supplied, the experimental DMatrix parameter"
|
||||
"`enable_categorical` must be set to `True`."
|
||||
)
|
||||
|
||||
@ -407,89 +412,122 @@ def is_pd_sparse_dtype(dtype: PandasDType) -> bool:
|
||||
return is_sparse(dtype)
|
||||
|
||||
|
||||
def pandas_cat_null(data: DataFrame) -> DataFrame:
|
||||
"""Handle categorical dtype and nullable extension types from pandas."""
|
||||
import pandas as pd
|
||||
|
||||
# handle category codes and nullable.
|
||||
cat_columns = []
|
||||
nul_columns = []
|
||||
# avoid an unnecessary conversion if possible
|
||||
for col, dtype in zip(data.columns, data.dtypes):
|
||||
if is_pd_cat_dtype(dtype):
|
||||
cat_columns.append(col)
|
||||
elif is_pa_ext_categorical_dtype(dtype):
|
||||
raise ValueError(
|
||||
"pyarrow dictionary type is not supported. Use pandas category instead."
|
||||
)
|
||||
elif is_nullable_dtype(dtype):
|
||||
nul_columns.append(col)
|
||||
|
||||
if cat_columns or nul_columns:
|
||||
# Avoid transformation due to: PerformanceWarning: DataFrame is highly
|
||||
# fragmented
|
||||
transformed = data.copy(deep=False)
|
||||
else:
|
||||
transformed = data
|
||||
|
||||
def cat_codes(ser: pd.Series) -> pd.Series:
|
||||
if is_pd_cat_dtype(ser.dtype):
|
||||
return ser.cat.codes
|
||||
assert is_pa_ext_categorical_dtype(ser.dtype)
|
||||
# Not yet supported, the index is not ordered for some reason. Alternately:
|
||||
# `combine_chunks().to_pandas().cat.codes`. The result is the same.
|
||||
return ser.array.__arrow_array__().combine_chunks().dictionary_encode().indices
|
||||
|
||||
if cat_columns:
|
||||
# DF doesn't have the cat attribute, as a result, we use apply here
|
||||
transformed[cat_columns] = (
|
||||
transformed[cat_columns]
|
||||
.apply(cat_codes)
|
||||
.astype(np.float32)
|
||||
.replace(-1.0, np.NaN)
|
||||
)
|
||||
if nul_columns:
|
||||
transformed[nul_columns] = transformed[nul_columns].astype(np.float32)
|
||||
|
||||
# TODO(jiamingy): Investigate the possibility of using dataframe protocol or arrow
|
||||
# IPC format for pandas so that we can apply the data transformation inside XGBoost
|
||||
# for better memory efficiency.
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
def pandas_ext_num_types(data: DataFrame) -> DataFrame:
|
||||
"""Experimental suppport for handling pyarrow extension numeric types."""
|
||||
def pandas_pa_type(ser: Any) -> np.ndarray:
|
||||
"""Handle pandas pyarrow extention."""
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
# No copy, callstack:
|
||||
# pandas.core.internals.managers.SingleBlockManager.array_values()
|
||||
# pandas.core.internals.blocks.EABackedBlock.values
|
||||
d_array: pd.arrays.ArrowExtensionArray = ser.array
|
||||
# no copy in __arrow_array__
|
||||
# ArrowExtensionArray._data is a chunked array
|
||||
aa: pa.ChunkedArray = d_array.__arrow_array__()
|
||||
# combine_chunks takes the most significant amount of time
|
||||
chunk: pa.Array = aa.combine_chunks()
|
||||
# When there's null value, we have to use copy
|
||||
zero_copy = chunk.null_count == 0
|
||||
# Alternately, we can use chunk.buffers(), which returns a list of buffers and
|
||||
# we need to concatenate them ourselves.
|
||||
# FIXME(jiamingy): Is there a better way to access the arrow buffer along with
|
||||
# its mask?
|
||||
# Buffers from chunk.buffers() have the address attribute, but don't expose the
|
||||
# mask.
|
||||
arr: np.ndarray = chunk.to_numpy(zero_copy_only=zero_copy, writable=False)
|
||||
arr, _ = _ensure_np_dtype(arr, arr.dtype)
|
||||
return arr
|
||||
|
||||
|
||||
def pandas_transform_data(data: DataFrame) -> List[np.ndarray]:
|
||||
"""Handle categorical dtype and extension types from pandas."""
|
||||
import pandas as pd
|
||||
from pandas import Float32Dtype, Float64Dtype
|
||||
|
||||
result: List[np.ndarray] = []
|
||||
|
||||
def cat_codes(ser: pd.Series) -> np.ndarray:
|
||||
if is_pd_cat_dtype(ser.dtype):
|
||||
return _ensure_np_dtype(
|
||||
ser.cat.codes.astype(np.float32)
|
||||
.replace(-1.0, np.NaN)
|
||||
.to_numpy(na_value=np.nan),
|
||||
np.float32,
|
||||
)[0]
|
||||
# Not yet supported, the index is not ordered for some reason. Alternately:
|
||||
# `combine_chunks().to_pandas().cat.codes`. The result is the same.
|
||||
assert is_pa_ext_categorical_dtype(ser.dtype)
|
||||
return (
|
||||
ser.array.__arrow_array__()
|
||||
.combine_chunks()
|
||||
.dictionary_encode()
|
||||
.indices.astype(np.float32)
|
||||
.replace(-1.0, np.NaN)
|
||||
)
|
||||
|
||||
def nu_type(ser: pd.Series) -> np.ndarray:
|
||||
# Avoid conversion when possible
|
||||
if isinstance(dtype, Float32Dtype):
|
||||
res_dtype: NumpyDType = np.float32
|
||||
elif isinstance(dtype, Float64Dtype):
|
||||
res_dtype = np.float64
|
||||
else:
|
||||
res_dtype = np.float32
|
||||
return _ensure_np_dtype(
|
||||
ser.to_numpy(dtype=res_dtype, na_value=np.nan), res_dtype
|
||||
)[0]
|
||||
|
||||
def oth_type(ser: pd.Series) -> np.ndarray:
|
||||
# The dtypes module is added in 1.25.
|
||||
npdtypes = np.lib.NumpyVersion(np.__version__) > np.lib.NumpyVersion("1.25.0")
|
||||
npdtypes = npdtypes and isinstance(
|
||||
ser.dtype,
|
||||
(
|
||||
# pylint: disable=no-member
|
||||
np.dtypes.Float32DType, # type: ignore
|
||||
# pylint: disable=no-member
|
||||
np.dtypes.Float64DType, # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
if npdtypes or dtype in {np.float32, np.float64}:
|
||||
array = ser.to_numpy()
|
||||
else:
|
||||
# Specifying the dtype can significantly slow down the conversion (about
|
||||
# 15% slow down for dense inplace-predict)
|
||||
array = ser.to_numpy(dtype=np.float32, na_value=np.nan)
|
||||
return _ensure_np_dtype(array, array.dtype)[0]
|
||||
|
||||
for col, dtype in zip(data.columns, data.dtypes):
|
||||
if not is_pa_ext_dtype(dtype):
|
||||
continue
|
||||
# No copy, callstack:
|
||||
# pandas.core.internals.managers.SingleBlockManager.array_values()
|
||||
# pandas.core.internals.blocks.EABackedBlock.values
|
||||
d_array: pd.arrays.ArrowExtensionArray = data[col].array
|
||||
# no copy in __arrow_array__
|
||||
# ArrowExtensionArray._data is a chunked array
|
||||
aa: pa.ChunkedArray = d_array.__arrow_array__()
|
||||
chunk: pa.Array = aa.combine_chunks()
|
||||
# Alternately, we can use chunk.buffers(), which returns a list of buffers and
|
||||
# we need to concatenate them ourselves.
|
||||
arr = chunk.__array__()
|
||||
data[col] = arr
|
||||
return data
|
||||
if is_pa_ext_categorical_dtype(dtype):
|
||||
raise ValueError(
|
||||
"pyarrow dictionary type is not supported. Use pandas category instead."
|
||||
)
|
||||
if is_pd_cat_dtype(dtype):
|
||||
result.append(cat_codes(data[col]))
|
||||
elif is_pa_ext_dtype(dtype):
|
||||
result.append(pandas_pa_type(data[col]))
|
||||
elif is_nullable_dtype(dtype):
|
||||
result.append(nu_type(data[col]))
|
||||
elif is_pd_sparse_dtype(dtype):
|
||||
arr = cast(pd.arrays.SparseArray, data[col].values)
|
||||
arr = arr.to_dense()
|
||||
if _is_np_array_like(arr):
|
||||
arr, _ = _ensure_np_dtype(arr, arr.dtype)
|
||||
result.append(arr)
|
||||
else:
|
||||
result.append(oth_type(data[col]))
|
||||
|
||||
# FIXME(jiamingy): Investigate the possibility of using dataframe protocol or arrow
|
||||
# IPC format for pandas so that we can apply the data transformation inside XGBoost
|
||||
# for better memory efficiency.
|
||||
return result
|
||||
|
||||
|
||||
def _transform_pandas_df(
|
||||
data: DataFrame,
|
||||
enable_categorical: bool,
|
||||
feature_names: Optional[FeatureNames] = None,
|
||||
feature_types: Optional[FeatureTypes] = None,
|
||||
meta: Optional[str] = None,
|
||||
meta_type: Optional[NumpyDType] = None,
|
||||
) -> Tuple[np.ndarray, Optional[FeatureNames], Optional[FeatureTypes]]:
|
||||
pyarrow_extension = False
|
||||
def pandas_check_dtypes(data: DataFrame, enable_categorical: bool) -> None:
|
||||
"""Validate the input types, returns True if the dataframe is backed by arrow."""
|
||||
sparse_extension = False
|
||||
|
||||
for dtype in data.dtypes:
|
||||
if not (
|
||||
(dtype.name in _pandas_dtype_mapper)
|
||||
@ -498,27 +536,65 @@ def _transform_pandas_df(
|
||||
or is_pa_ext_dtype(dtype)
|
||||
):
|
||||
_invalid_dataframe_dtype(data)
|
||||
if is_pa_ext_dtype(dtype):
|
||||
pyarrow_extension = True
|
||||
|
||||
if is_pd_sparse_dtype(dtype):
|
||||
sparse_extension = True
|
||||
|
||||
if sparse_extension:
|
||||
warnings.warn("Sparse arrays from pandas are converted into dense.")
|
||||
|
||||
|
||||
class PandasTransformed:
|
||||
"""A storage class for transformed pandas DataFrame."""
|
||||
|
||||
def __init__(self, columns: List[np.ndarray]) -> None:
|
||||
self.columns = columns
|
||||
|
||||
def array_interface(self) -> bytes:
|
||||
"""Return a byte string for JSON encoded array interface."""
|
||||
aitfs = list(map(_array_interface_dict, self.columns))
|
||||
sarrays = bytes(json.dumps(aitfs), "utf-8")
|
||||
return sarrays
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, int]:
|
||||
"""Return shape of the transformed DataFrame."""
|
||||
return self.columns[0].shape[0], len(self.columns)
|
||||
|
||||
|
||||
def _transform_pandas_df(
|
||||
data: DataFrame,
|
||||
enable_categorical: bool,
|
||||
feature_names: Optional[FeatureNames] = None,
|
||||
feature_types: Optional[FeatureTypes] = None,
|
||||
meta: Optional[str] = None,
|
||||
) -> Tuple[PandasTransformed, Optional[FeatureNames], Optional[FeatureTypes]]:
|
||||
pandas_check_dtypes(data, enable_categorical)
|
||||
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
|
||||
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")
|
||||
|
||||
feature_names, feature_types = pandas_feature_info(
|
||||
data, meta, feature_names, feature_types, enable_categorical
|
||||
)
|
||||
|
||||
transformed = pandas_cat_null(data)
|
||||
if pyarrow_extension:
|
||||
if transformed is data:
|
||||
transformed = data.copy(deep=False)
|
||||
transformed = pandas_ext_num_types(transformed)
|
||||
arrays = pandas_transform_data(data)
|
||||
return PandasTransformed(arrays), feature_names, feature_types
|
||||
|
||||
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
|
||||
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")
|
||||
|
||||
dtype = meta_type if meta_type else np.float32
|
||||
arr: np.ndarray = transformed.values
|
||||
if meta_type:
|
||||
arr = arr.astype(dtype)
|
||||
return arr, feature_names, feature_types
|
||||
def _meta_from_pandas_df(
|
||||
data: DataType,
|
||||
name: str,
|
||||
dtype: Optional[NumpyDType],
|
||||
handle: ctypes.c_void_p,
|
||||
) -> None:
|
||||
data, _, _ = _transform_pandas_df(data, False, meta=name)
|
||||
if len(data.columns) == 1:
|
||||
array = data.columns[0]
|
||||
else:
|
||||
array = np.stack(data.columns).T
|
||||
|
||||
array, dtype = _ensure_np_dtype(array, dtype)
|
||||
_meta_from_numpy(array, name, dtype, handle)
|
||||
|
||||
|
||||
def _from_pandas_df(
|
||||
@ -530,12 +606,21 @@ def _from_pandas_df(
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
data, feature_names, feature_types = _transform_pandas_df(
|
||||
df, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
return _from_numpy_array(
|
||||
data, missing, nthread, feature_names, feature_types, data_split_mode
|
||||
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromColumnar(
|
||||
df.array_interface(),
|
||||
make_jcargs(
|
||||
nthread=nthread, missing=missing, data_split_mode=data_split_mode
|
||||
),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
def _is_pandas_series(data: DataType) -> bool:
|
||||
@ -550,7 +635,12 @@ def _meta_from_pandas_series(
|
||||
data: DataType, name: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
|
||||
) -> None:
|
||||
"""Help transform pandas series for meta data like labels"""
|
||||
data = data.values.astype("float")
|
||||
if is_pd_sparse_dtype(data.dtype):
|
||||
data = data.values.to_dense().astype(np.float32)
|
||||
elif is_pa_ext_dtype(data.dtype):
|
||||
data = pandas_pa_type(data)
|
||||
else:
|
||||
data = data.to_numpy(np.float32, na_value=np.nan)
|
||||
|
||||
if is_pd_sparse_dtype(getattr(data, "dtype", data)):
|
||||
data = data.to_dense() # type: ignore
|
||||
@ -732,6 +822,8 @@ def _arrow_transform(data: DataType) -> Any:
|
||||
return pd.ArrowDtype(pa.bool_())
|
||||
return None
|
||||
|
||||
# For common cases, this is zero-copy, can check with:
|
||||
# pa.total_allocated_bytes()
|
||||
df = data.to_pandas(types_mapper=type_mapper)
|
||||
return df
|
||||
|
||||
@ -859,11 +951,10 @@ def _from_cudf_df(
|
||||
)
|
||||
interfaces_str = _cudf_array_interfaces(data, cat_codes)
|
||||
handle = ctypes.c_void_p()
|
||||
config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromCudaColumnar(
|
||||
interfaces_str,
|
||||
config,
|
||||
make_jcargs(nthread=nthread, missing=missing),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
@ -1221,8 +1312,7 @@ def dispatch_meta_backend(
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_df(data):
|
||||
data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype)
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
_meta_from_pandas_df(data, name, dtype=dtype, handle=handle)
|
||||
return
|
||||
if _is_pandas_series(data):
|
||||
_meta_from_pandas_series(data, name, dtype, handle)
|
||||
@ -1244,8 +1334,7 @@ def dispatch_meta_backend(
|
||||
_meta_from_dt(data, name, dtype, handle)
|
||||
return
|
||||
if _is_modin_df(data):
|
||||
data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype)
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
_meta_from_pandas_df(data, name, dtype=dtype, handle=handle)
|
||||
return
|
||||
if _is_modin_series(data):
|
||||
data = data.values.astype("float")
|
||||
@ -1317,11 +1406,10 @@ def _proxy_transform(
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_df(data):
|
||||
arr, feature_names, feature_types = _transform_pandas_df(
|
||||
df, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
arr, _ = _ensure_np_dtype(arr, arr.dtype)
|
||||
return arr, None, feature_names, feature_types
|
||||
return df, None, feature_names, feature_types
|
||||
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
|
||||
|
||||
|
||||
@ -1356,6 +1444,9 @@ def dispatch_proxy_set_data(
|
||||
if not allow_host:
|
||||
raise err
|
||||
|
||||
if isinstance(data, PandasTransformed):
|
||||
proxy._set_data_from_pandas(data) # pylint: disable=W0212
|
||||
return
|
||||
if _is_np_array_like(data):
|
||||
_check_data_shape(data)
|
||||
proxy._set_data_from_array(data) # pylint: disable=W0212
|
||||
|
||||
@ -361,49 +361,57 @@ XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int
|
||||
XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
|
||||
char const *c_interface_str) {
|
||||
XGB_DLL int XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
|
||||
char const *c_interface_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetCUDAArray(c_interface_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle,
|
||||
char const *c_interface_str) {
|
||||
XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle, char const *c_interface_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetCUDAArray(c_interface_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle,
|
||||
char const *c_interface_str) {
|
||||
XGB_DLL int XGProxyDMatrixSetDataColumnar(DMatrixHandle handle, char const *c_interface_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetColumnarData(c_interface_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle, char const *c_interface_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetArrayData(c_interface_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
|
||||
char const *indices, char const *data,
|
||||
xgboost::bst_ulong ncol) {
|
||||
XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, char const *indices,
|
||||
char const *data, xgboost::bst_ulong ncol) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(indptr);
|
||||
@ -411,7 +419,7 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
|
||||
xgboost_CHECK_C_ARG_PTR(data);
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetCSRData(indptr, indices, data, ncol, true);
|
||||
API_END();
|
||||
@ -429,6 +437,25 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromColumnar(char const *data, char const *c_json_config,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||
xgboost_CHECK_C_ARG_PTR(data);
|
||||
|
||||
auto config = Json::Load(c_json_config);
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(config, "nthread", 0);
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
|
||||
data::ColumnarAdapter adapter{data};
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char const *data,
|
||||
xgboost::bst_ulong ncol, char const *c_json_config,
|
||||
DMatrixHandle *out) {
|
||||
@ -1196,6 +1223,27 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_in
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
|
||||
char const *c_json_config, DMatrixHandle m,
|
||||
xgboost::bst_ulong const **out_shape,
|
||||
xgboost::bst_ulong *out_dim, const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::shared_ptr<DMatrix> p_m{nullptr};
|
||||
if (!m) {
|
||||
p_m.reset(new data::DMatrixProxy);
|
||||
} else {
|
||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||
}
|
||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||
CHECK(proxy) << "Invalid input type for inplace predict.";
|
||||
xgboost_CHECK_C_ARG_PTR(array_interface);
|
||||
proxy->SetColumnarData(array_interface);
|
||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, char const *indices,
|
||||
char const *data, xgboost::bst_ulong cols,
|
||||
char const *c_json_config, DMatrixHandle m,
|
||||
|
||||
@ -97,6 +97,7 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid
|
||||
// the nnz from info is not reliable as sketching might be the first place to go through
|
||||
// the data.
|
||||
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
|
||||
CHECK(!this->columns_size_.empty());
|
||||
this->PushRowPageImpl(batch, base_rowid, weights, info.num_nonzero_, info.num_col_, is_dense,
|
||||
is_valid);
|
||||
}
|
||||
@ -110,6 +111,7 @@ INSTANTIATE(CSRArrayAdapterBatch)
|
||||
INSTANTIATE(CSCAdapterBatch)
|
||||
INSTANTIATE(DataTableAdapterBatch)
|
||||
INSTANTIATE(SparsePageAdapterBatch)
|
||||
INSTANTIATE(ColumnarAdapterBatch)
|
||||
|
||||
namespace {
|
||||
/**
|
||||
|
||||
@ -25,9 +25,7 @@
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/string_view.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
namespace xgboost::data {
|
||||
/** External data formats should implement an adapter as below. The
|
||||
* adapter provides a uniform access to data outside xgboost, allowing
|
||||
* construction of DMatrix objects from a range of sources without duplicating
|
||||
@ -279,9 +277,9 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
return Line{array_interface_, idx};
|
||||
}
|
||||
|
||||
size_t NumRows() const { return array_interface_.Shape(0); }
|
||||
size_t NumCols() const { return array_interface_.Shape(1); }
|
||||
size_t Size() const { return this->NumRows(); }
|
||||
[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
|
||||
[[nodiscard]] std::size_t NumCols() const { return array_interface_.Shape(1); }
|
||||
[[nodiscard]] std::size_t Size() const { return this->NumRows(); }
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface<2> array_interface)
|
||||
: array_interface_{std::move(array_interface)} {}
|
||||
@ -326,11 +324,11 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
|
||||
offset_{offset} {}
|
||||
|
||||
COOTuple GetElement(std::size_t idx) const {
|
||||
[[nodiscard]] COOTuple GetElement(std::size_t idx) const {
|
||||
return {ridx_, TypedIndex<std::size_t, 1>{indices_}(offset_ + idx), values_(offset_ + idx)};
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
[[nodiscard]] std::size_t Size() const {
|
||||
return values_.Shape(0);
|
||||
}
|
||||
};
|
||||
@ -539,9 +537,11 @@ class CSCArrayAdapter : public detail::SingleBatchDataIter<CSCArrayAdapterBatch>
|
||||
batch_{CSCArrayAdapterBatch{indptr_, indices_, values_}} {}
|
||||
|
||||
// JVM package sends 0 as unknown
|
||||
size_t NumRows() const { return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_; }
|
||||
size_t NumColumns() const { return indptr_.n - 1; }
|
||||
const CSCArrayAdapterBatch& Value() const override { return batch_; }
|
||||
[[nodiscard]] std::size_t NumRows() const {
|
||||
return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_;
|
||||
}
|
||||
[[nodiscard]] std::size_t NumColumns() const { return indptr_.n - 1; }
|
||||
[[nodiscard]] const CSCArrayAdapterBatch& Value() const override { return batch_; }
|
||||
};
|
||||
|
||||
class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
@ -634,15 +634,15 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
Line(std::size_t ridx, void const* const* const data, std::vector<DTType> const& ft)
|
||||
: row_idx_{ridx}, data_{data}, feature_types_{ft} {}
|
||||
std::size_t Size() const { return feature_types_.size(); }
|
||||
COOTuple GetElement(std::size_t idx) const {
|
||||
[[nodiscard]] std::size_t Size() const { return feature_types_.size(); }
|
||||
[[nodiscard]] COOTuple GetElement(std::size_t idx) const {
|
||||
return COOTuple{row_idx_, idx, DTGetValue(data_[idx], feature_types_[idx], row_idx_)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
size_t Size() const { return num_rows_; }
|
||||
const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
|
||||
[[nodiscard]] size_t Size() const { return num_rows_; }
|
||||
[[nodiscard]] const Line GetLine(std::size_t ridx) const { return {ridx, data_, feature_types_}; }
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
|
||||
private:
|
||||
@ -659,9 +659,9 @@ class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatc
|
||||
: batch_(data, feature_stypes, num_rows, num_features),
|
||||
num_rows_(num_rows),
|
||||
num_columns_(num_features) {}
|
||||
const DataTableAdapterBatch& Value() const override { return batch_; }
|
||||
std::size_t NumRows() const { return num_rows_; }
|
||||
std::size_t NumColumns() const { return num_columns_; }
|
||||
[[nodiscard]] const DataTableAdapterBatch& Value() const override { return batch_; }
|
||||
[[nodiscard]] std::size_t NumRows() const { return num_rows_; }
|
||||
[[nodiscard]] std::size_t NumColumns() const { return num_columns_; }
|
||||
|
||||
private:
|
||||
DataTableAdapterBatch batch_;
|
||||
@ -669,6 +669,74 @@ class DataTableAdapter : public detail::SingleBatchDataIter<DataTableAdapterBatc
|
||||
std::size_t num_columns_;
|
||||
};
|
||||
|
||||
class ColumnarAdapterBatch : public detail::NoMetaInfo {
|
||||
common::Span<ArrayInterface<1, false>> columns_;
|
||||
|
||||
class Line {
|
||||
common::Span<ArrayInterface<1, false>> const& columns_;
|
||||
std::size_t ridx_;
|
||||
|
||||
public:
|
||||
explicit Line(common::Span<ArrayInterface<1, false>> const& columns, std::size_t ridx)
|
||||
: columns_{columns}, ridx_{ridx} {}
|
||||
[[nodiscard]] std::size_t Size() const { return columns_.empty() ? 0 : columns_.size(); }
|
||||
|
||||
[[nodiscard]] COOTuple GetElement(std::size_t idx) const {
|
||||
return {ridx_, idx, columns_[idx](ridx_)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ColumnarAdapterBatch() = default;
|
||||
explicit ColumnarAdapterBatch(common::Span<ArrayInterface<1, false>> columns)
|
||||
: columns_{columns} {}
|
||||
[[nodiscard]] Line GetLine(std::size_t ridx) const { return Line{columns_, ridx}; }
|
||||
[[nodiscard]] std::size_t Size() const {
|
||||
return columns_.empty() ? 0 : columns_.front().Shape(0);
|
||||
}
|
||||
[[nodiscard]] std::size_t NumCols() const { return columns_.empty() ? 0 : columns_.size(); }
|
||||
[[nodiscard]] std::size_t NumRows() const { return this->Size(); }
|
||||
|
||||
static constexpr bool kIsRowMajor = true;
|
||||
};
|
||||
|
||||
class ColumnarAdapter : public detail::SingleBatchDataIter<ColumnarAdapterBatch> {
|
||||
std::vector<ArrayInterface<1, false>> columns_;
|
||||
ColumnarAdapterBatch batch_;
|
||||
|
||||
public:
|
||||
explicit ColumnarAdapter(StringView columns) {
|
||||
auto jarray = Json::Load(columns);
|
||||
CHECK(IsA<Array>(jarray));
|
||||
auto const& array = get<Array const>(jarray);
|
||||
for (auto col : array) {
|
||||
columns_.emplace_back(get<Object const>(col));
|
||||
}
|
||||
bool consistent =
|
||||
columns_.empty() ||
|
||||
std::all_of(columns_.cbegin(), columns_.cend(), [&](ArrayInterface<1, false> const& array) {
|
||||
return array.Shape(0) == columns_[0].Shape(0);
|
||||
});
|
||||
CHECK(consistent) << "Size of columns should be the same.";
|
||||
batch_ = ColumnarAdapterBatch{columns_};
|
||||
}
|
||||
|
||||
[[nodiscard]] ColumnarAdapterBatch const& Value() const override { return batch_; }
|
||||
|
||||
[[nodiscard]] std::size_t NumRows() const {
|
||||
if (!columns_.empty()) {
|
||||
return columns_.front().shape[0];
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
[[nodiscard]] std::size_t NumColumns() const {
|
||||
if (!columns_.empty()) {
|
||||
return columns_.size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
class FileAdapterBatch {
|
||||
public:
|
||||
class Line {
|
||||
@ -851,6 +919,5 @@ class SparsePageAdapterBatch {
|
||||
Line GetLine(size_t ridx) const { return Line{page_[ridx].data(), page_[ridx].size(), ridx}; }
|
||||
size_t Size() const { return page_.Size(); }
|
||||
};
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||
|
||||
@ -947,38 +947,24 @@ DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const st
|
||||
return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
|
||||
}
|
||||
|
||||
template DMatrix* DMatrix::Create<data::DenseAdapter>(data::DenseAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::ArrayAdapter>(data::ArrayAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSRAdapter>(data::CSRAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSCAdapter>(data::CSCAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::DataTableAdapter>(data::DataTableAdapter* adapter,
|
||||
float missing, std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(data::FileAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(data::CSRArrayAdapter* adapter,
|
||||
float missing, std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter* adapter,
|
||||
float missing, std::int32_t nthread,
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
// Instantiate the factory function for various adapters
|
||||
#define INSTANTIATION_CREATE(_AdapterT) \
|
||||
template DMatrix* DMatrix::Create<data::_AdapterT>( \
|
||||
data::_AdapterT * adapter, float missing, std::int32_t nthread, \
|
||||
const std::string& cache_prefix, DataSplitMode data_split_mode);
|
||||
|
||||
INSTANTIATION_CREATE(DenseAdapter)
|
||||
INSTANTIATION_CREATE(ArrayAdapter)
|
||||
INSTANTIATION_CREATE(CSRAdapter)
|
||||
INSTANTIATION_CREATE(CSCAdapter)
|
||||
INSTANTIATION_CREATE(DataTableAdapter)
|
||||
INSTANTIATION_CREATE(FileAdapter)
|
||||
INSTANTIATION_CREATE(CSRArrayAdapter)
|
||||
INSTANTIATION_CREATE(CSCArrayAdapter)
|
||||
INSTANTIATION_CREATE(ColumnarAdapter)
|
||||
|
||||
#undef INSTANTIATION_CREATE
|
||||
|
||||
template DMatrix* DMatrix::Create(
|
||||
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||
float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
|
||||
@ -1156,7 +1142,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
builder.InitStorage();
|
||||
|
||||
// Second pass over batch, placing elements in correct position
|
||||
|
||||
auto is_valid = data::IsValidFunctor{missing};
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
@ -1253,9 +1238,10 @@ template uint64_t SparsePage::Push(const data::CSCAdapterBatch& batch, float mis
|
||||
template uint64_t SparsePage::Push(const data::DataTableAdapterBatch& batch, float missing,
|
||||
int nthread);
|
||||
template uint64_t SparsePage::Push(const data::FileAdapterBatch& batch, float missing, int nthread);
|
||||
template uint64_t SparsePage::Push(const data::ColumnarAdapterBatch& batch, float missing,
|
||||
std::int32_t nthread);
|
||||
|
||||
namespace data {
|
||||
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
|
||||
DMLC_REGISTRY_LINK_TAG(gradient_index_format);
|
||||
|
||||
@ -120,7 +120,7 @@ void GHistIndexMatrix::PushAdapterBatchColumns(Context const *ctx, Batch const &
|
||||
INSTANTIATION_PUSH(data::CSRArrayAdapterBatch)
|
||||
INSTANTIATION_PUSH(data::ArrayAdapterBatch)
|
||||
INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
|
||||
|
||||
INSTANTIATION_PUSH(data::ColumnarAdapterBatch)
|
||||
#undef INSTANTIATION_PUSH
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
|
||||
|
||||
@ -5,7 +5,22 @@
|
||||
|
||||
#include "proxy_dmatrix.h"
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::data {
|
||||
void DMatrixProxy::SetColumnarData(StringView interface_str) {
|
||||
std::shared_ptr<ColumnarAdapter> adapter{new ColumnarAdapter{interface_str}};
|
||||
this->batch_ = adapter;
|
||||
this->Info().num_col_ = adapter->NumColumns();
|
||||
this->Info().num_row_ = adapter->NumRows();
|
||||
this->ctx_.Init(Args{{"device", "cpu"}});
|
||||
}
|
||||
|
||||
void DMatrixProxy::SetArrayData(StringView interface_str) {
|
||||
std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter{interface_str}};
|
||||
this->batch_ = adapter;
|
||||
|
||||
@ -62,6 +62,8 @@ class DMatrixProxy : public DMatrix {
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
void SetColumnarData(StringView interface_str);
|
||||
|
||||
void SetArrayData(StringView interface_str);
|
||||
void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
|
||||
bst_feature_t n_features, bool on_host);
|
||||
@ -151,6 +153,17 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
|
||||
if (type_error) {
|
||||
*type_error = false;
|
||||
}
|
||||
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<ColumnarAdapter>)) {
|
||||
if constexpr (get_value) {
|
||||
auto value = std::any_cast<std::shared_ptr<ColumnarAdapter>>(proxy->Adapter())->Value();
|
||||
return fn(value);
|
||||
} else {
|
||||
auto value = std::any_cast<std::shared_ptr<ColumnarAdapter>>(proxy->Adapter());
|
||||
return fn(value);
|
||||
}
|
||||
if (type_error) {
|
||||
*type_error = false;
|
||||
}
|
||||
} else {
|
||||
if (type_error) {
|
||||
*type_error = true;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014~2023 by XGBoost Contributors
|
||||
* Copyright 2014~2023, XGBoost Contributors
|
||||
* \file simple_dmatrix.cc
|
||||
* \brief the input data structure for gradient boosting
|
||||
* \author Tianqi Chen
|
||||
@ -356,6 +356,8 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
||||
DataSplitMode data_split_mode);
|
||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
|
||||
DataSplitMode data_split_mode);
|
||||
template SimpleDMatrix::SimpleDMatrix(ColumnarAdapter* adapter, float missing, int nthread,
|
||||
DataSplitMode data_split_mode);
|
||||
template SimpleDMatrix::SimpleDMatrix(
|
||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||
float missing, int nthread, DataSplitMode data_split_mode);
|
||||
|
||||
@ -761,6 +761,9 @@ class CPUPredictor : public Predictor {
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter, 1>(x, p_m, model, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::ColumnarAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::ColumnarAdapter, kBlockOfRowsSize>(
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -16,7 +16,7 @@ pytestmark = pytest.mark.skipif(**tm.no_modin())
|
||||
|
||||
class TestModin:
|
||||
@pytest.mark.xfail
|
||||
def test_modin(self):
|
||||
def test_modin(self) -> None:
|
||||
df = md.DataFrame([[1, 2., True], [2, 3., False]],
|
||||
columns=['a', 'b', 'c'])
|
||||
dm = xgb.DMatrix(df, label=md.Series([1, 2]))
|
||||
@ -67,8 +67,8 @@ class TestModin:
|
||||
enable_categorical=False)
|
||||
exp = np.array([[1., 1., 0., 0.],
|
||||
[2., 0., 1., 0.],
|
||||
[3., 0., 0., 1.]])
|
||||
np.testing.assert_array_equal(result, exp)
|
||||
[3., 0., 0., 1.]]).T
|
||||
np.testing.assert_array_equal(result.columns, exp)
|
||||
dm = xgb.DMatrix(dummies)
|
||||
assert dm.feature_names == ['B', 'A_X', 'A_Y', 'A_Z']
|
||||
assert dm.feature_types == ['int', 'int', 'int', 'int']
|
||||
@ -108,20 +108,23 @@ class TestModin:
|
||||
|
||||
def test_modin_label(self):
|
||||
# label must be a single column
|
||||
df = md.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||
df = md.DataFrame({"A": ["X", "Y", "Z"], "B": [1, 2, 3]})
|
||||
with pytest.raises(ValueError):
|
||||
xgb.data._transform_pandas_df(df, False, None, None, 'label', 'float')
|
||||
xgb.data._transform_pandas_df(df, False, None, None, "label")
|
||||
|
||||
# label must be supported dtype
|
||||
df = md.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
||||
df = md.DataFrame({"A": np.array(["a", "b", "c"], dtype=object)})
|
||||
with pytest.raises(ValueError):
|
||||
xgb.data._transform_pandas_df(df, False, None, None, 'label', 'float')
|
||||
xgb.data._transform_pandas_df(df, False, None, None, "label")
|
||||
|
||||
df = md.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
||||
result, _, _ = xgb.data._transform_pandas_df(df, False, None, None,
|
||||
'label', 'float')
|
||||
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
||||
dtype=float))
|
||||
df = md.DataFrame({"A": np.array([1, 2, 3], dtype=int)})
|
||||
result, _, _ = xgb.data._transform_pandas_df(
|
||||
df, False, None, None, "label"
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
np.stack(result.columns, axis=1),
|
||||
np.array([[1.0], [2.0], [3.0]], dtype=float),
|
||||
)
|
||||
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
@ -105,8 +105,8 @@ class TestPandas:
|
||||
result, _, _ = xgb.data._transform_pandas_df(dummies, enable_categorical=False)
|
||||
exp = np.array(
|
||||
[[1.0, 1.0, 0.0, 0.0], [2.0, 0.0, 1.0, 0.0], [3.0, 0.0, 0.0, 1.0]]
|
||||
)
|
||||
np.testing.assert_array_equal(result, exp)
|
||||
).T
|
||||
np.testing.assert_array_equal(result.columns, exp)
|
||||
dm = xgb.DMatrix(dummies, data_split_mode=data_split_mode)
|
||||
assert dm.num_row() == 3
|
||||
if data_split_mode == DataSplitMode.ROW:
|
||||
@ -202,6 +202,20 @@ class TestPandas:
|
||||
else:
|
||||
assert dm.num_col() == 1 * world_size
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_multi_target(self) -> None:
|
||||
from sklearn.datasets import make_regression
|
||||
|
||||
X, y = make_regression(n_samples=1024, n_features=4, n_targets=3)
|
||||
ydf = pd.DataFrame({i: y[:, i] for i in range(y.shape[1])})
|
||||
|
||||
Xy = xgb.DMatrix(X, ydf)
|
||||
assert Xy.num_row() == y.shape[0]
|
||||
assert Xy.get_label().size == y.shape[0] * y.shape[1]
|
||||
Xy = xgb.QuantileDMatrix(X, ydf)
|
||||
assert Xy.num_row() == y.shape[0]
|
||||
assert Xy.get_label().size == y.shape[0] * y.shape[1]
|
||||
|
||||
def test_slice(self):
|
||||
rng = np.random.RandomState(1994)
|
||||
rows = 100
|
||||
@ -233,13 +247,14 @@ class TestPandas:
|
||||
X, enable_categorical=True
|
||||
)
|
||||
|
||||
assert transformed[:, 0].min() == 0
|
||||
assert transformed.columns[0].min() == 0
|
||||
|
||||
# test missing value
|
||||
X = pd.DataFrame({"f0": ["a", "b", np.NaN]})
|
||||
X["f0"] = X["f0"].astype("category")
|
||||
arr, _, _ = xgb.data._transform_pandas_df(X, enable_categorical=True)
|
||||
assert not np.any(arr == -1.0)
|
||||
for c in arr.columns:
|
||||
assert not np.any(c == -1.0)
|
||||
|
||||
X = X["f0"]
|
||||
y = y[: X.shape[0]]
|
||||
@ -273,24 +288,25 @@ class TestPandas:
|
||||
predt_dense = booster.predict(xgb.DMatrix(X.sparse.to_dense()))
|
||||
np.testing.assert_allclose(predt_sparse, predt_dense)
|
||||
|
||||
def test_pandas_label(self, data_split_mode=DataSplitMode.ROW):
|
||||
def test_pandas_label(
|
||||
self, data_split_mode: DataSplitMode = DataSplitMode.ROW
|
||||
) -> None:
|
||||
world_size = xgb.collective.get_world_size()
|
||||
# label must be a single column
|
||||
df = pd.DataFrame({"A": ["X", "Y", "Z"], "B": [1, 2, 3]})
|
||||
with pytest.raises(ValueError):
|
||||
xgb.data._transform_pandas_df(df, False, None, None, "label", "float")
|
||||
xgb.data._transform_pandas_df(df, False, None, None, "label")
|
||||
|
||||
# label must be supported dtype
|
||||
df = pd.DataFrame({"A": np.array(["a", "b", "c"], dtype=object)})
|
||||
with pytest.raises(ValueError):
|
||||
xgb.data._transform_pandas_df(df, False, None, None, "label", "float")
|
||||
xgb.data._transform_pandas_df(df, False, None, None, "label")
|
||||
|
||||
df = pd.DataFrame({"A": np.array([1, 2, 3], dtype=int)})
|
||||
result, _, _ = xgb.data._transform_pandas_df(
|
||||
df, False, None, None, "label", "float"
|
||||
)
|
||||
result, _, _ = xgb.data._transform_pandas_df(df, False, None, None, "label")
|
||||
np.testing.assert_array_equal(
|
||||
result, np.array([[1.0], [2.0], [3.0]], dtype=float)
|
||||
np.stack(result.columns, axis=1),
|
||||
np.array([[1.0], [2.0], [3.0]], dtype=float),
|
||||
)
|
||||
dm = xgb.DMatrix(
|
||||
np.random.randn(3, 2), label=df, data_split_mode=data_split_mode
|
||||
@ -507,6 +523,35 @@ class TestPandas:
|
||||
np.testing.assert_allclose(m_orig.get_label(), m_etype.get_label())
|
||||
np.testing.assert_allclose(m_etype.get_label(), y.values)
|
||||
|
||||
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||
def test_mixed_type(self, DMatrixT: Type[xgb.DMatrix]) -> None:
|
||||
f0 = np.arange(0, 4)
|
||||
f1 = pd.Series(f0, dtype="int64[pyarrow]")
|
||||
f2l = list(f0)
|
||||
f2l[0] = pd.NA
|
||||
f2 = pd.Series(f2l, dtype=pd.Int64Dtype())
|
||||
|
||||
df = pd.DataFrame({"f0": f0})
|
||||
df["f2"] = f2
|
||||
|
||||
m = DMatrixT(df)
|
||||
assert m.num_col() == df.shape[1]
|
||||
|
||||
df["f1"] = f1
|
||||
m = DMatrixT(df)
|
||||
assert m.num_col() == df.shape[1]
|
||||
assert m.num_row() == df.shape[0]
|
||||
assert m.num_nonmissing() == df.size - 1
|
||||
assert m.feature_names == list(map(str, df.columns))
|
||||
assert m.feature_types == ["int"] * df.shape[1]
|
||||
|
||||
y = f0
|
||||
m.set_info(label=y)
|
||||
booster = xgb.train({}, m)
|
||||
p0 = booster.inplace_predict(df)
|
||||
p1 = booster.predict(m)
|
||||
np.testing.assert_allclose(p0, p1)
|
||||
|
||||
@pytest.mark.skipif(tm.is_windows(), reason="Rabit does not run on windows")
|
||||
def test_pandas_column_split(self):
|
||||
tm.run_with_rabit(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user