Fix metainfo from DataFrame. (#5216)

* Fix metainfo from DataFrame.

* Unify helper functions for data and meta.
Jiaming Yuan 2020-01-22 16:29:44 +08:00 committed by GitHub
parent 5d4c24a1fc
commit 1891cc766d
5 changed files with 225 additions and 237 deletions
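
As a concrete illustration of the fix, meta information such as label and weight can now be passed as single-column DataFrames (a minimal sketch mirroring the tests added in this commit; variable names are illustrative):

import numpy as np
import pandas as pd
import xgboost as xgb

X = np.random.randn(32, 8)
y = pd.DataFrame(np.random.randn(32))                     # label as a DataFrame
w = pd.DataFrame(np.random.randn(32).astype(np.float32))  # weight as a DataFrame

dm = xgb.DMatrix(X, label=y, weight=w)
assert dm.num_row() == 32
# A DataFrame with more than one column passed as meta info raises ValueError.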

View File

@@ -231,10 +231,11 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64':
'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int',
'uint64': 'int', 'float16': 'float', 'float32': 'float',
'float64': 'float', 'bool': 'i'}
# Either object has cuda array interface or contains columns with interfaces
def _has_cuda_array_interface(data):
@@ -242,7 +243,8 @@ def _has_cuda_array_interface(data):
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
def _maybe_pandas_data(data, feature_names, feature_types):
def _maybe_pandas_data(data, feature_names, feature_types,
meta=None, meta_type=None):
"""Extract internal data from pd.DataFrame for DMatrix data"""
if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
@@ -250,51 +252,41 @@ def _maybe_pandas_data(data, feature_names, feature_types):
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [str(data.columns[i]) for i, dtype in
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
bad_fields = [
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
if dtype.name not in PANDAS_DTYPE_MAPPER
]
msg = """DataFrame.dtypes for data must be int, float or bool.
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None:
if feature_names is None and meta is None:
if isinstance(data.columns, MultiIndex):
feature_names = [
' '.join([str(x) for x in i])
for i in data.columns
' '.join([str(x) for x in i]) for i in data.columns
]
elif isinstance(data.columns, Int64Index):
feature_names = list(map(str, data.columns))
else:
feature_names = data.columns.format()
if feature_types is None:
feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]
if feature_types is None and meta is None:
feature_types = [
PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes
]
data = data.values.astype('float')
if meta and len(data.columns) > 1:
raise ValueError(
'DataFrame for {meta} cannot have multiple columns'.format(
meta=meta))
dtype = meta_type if meta_type else 'float'
data = data.values.astype(dtype)
return data, feature_names, feature_types
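
A hedged sketch of the meta path through the unified helper above (it mirrors test_pandas_label further down; with meta set, feature names and types are left as None):

import numpy as np
import pandas as pd
import xgboost as xgb

df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
data, names, types = xgb.core._maybe_pandas_data(df, None, None,
                                                 'label', 'float')
# data == np.array([[1.], [2.], [3.]]); names and types are both None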
def _maybe_pandas_label(label):
"""Extract internal data from pd.DataFrame for DMatrix label."""
if PANDAS_INSTALLED and isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError(
'DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER
for dtype in label_dtypes):
raise ValueError(
'DataFrame.dtypes for label must be int, float or bool')
label = label.values.astype('float')
# pd.Series can be passed to xgb as it is
return label
def _maybe_cudf_dataframe(data, feature_names, feature_types):
"""Extract internal data from cudf.DataFrame for DMatrix data."""
if not (CUDF_INSTALLED and isinstance(data,
@@ -324,11 +316,21 @@ DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
def _maybe_dt_data(data, feature_names, feature_types):
def _maybe_dt_data(data, feature_names, feature_types,
meta=None, meta_type=None):
"""Validate feature names and types if data table"""
if not isinstance(data, DataTable):
return data, feature_names, feature_types
if meta and data.shape[1] > 1:
raise ValueError(
'DataTable for label or weight cannot have multiple columns')
if meta:
# below requires new dt version
# extract first column
data = data.to_numpy()[:, 0].astype(meta_type)
return data, None, None
data_types_names = tuple(lt.name for lt in data.ltypes)
bad_fields = [data.names[i]
for i, type_name in enumerate(data_types_names)
@@ -338,40 +340,31 @@ def _maybe_dt_data(data, feature_names, feature_types):
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
if feature_names is None:
if feature_names is None and meta is None:
feature_names = data.names
# always return stypes for dt ingestion
if feature_types is not None:
raise ValueError('DataTable has own feature types, cannot pass them in')
raise ValueError(
'DataTable has own feature types, cannot pass them in.')
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
return data, feature_names, feature_types
def _maybe_dt_array(array):
"""Extract numpy array from single column data table"""
if not isinstance(array, DataTable) or array is None:
return array
if array.shape[1] > 1:
raise ValueError('DataTable for label or weight cannot have multiple columns')
# below requires new dt version
# extract first column
array = array.to_numpy()[:, 0].astype('float')
return array
def _convert_dataframes(data, feature_names, feature_types):
def _convert_dataframes(data, feature_names, feature_types,
meta=None, meta_type=None):
data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names,
feature_types)
feature_types,
meta,
meta_type)
data, feature_names, feature_types = _maybe_dt_data(data,
feature_names,
feature_types)
feature_types,
meta,
meta_type)
data, feature_names, feature_types = _maybe_cudf_dataframe(
data, feature_names, feature_types)
@@ -379,6 +372,23 @@ def _convert_dataframes(data, feature_names, feature_types):
return data, feature_names, feature_types
def _maybe_np_slice(data, dtype=np.float32):
'''Handle numpy slice. This can be removed if we use __array_interface__.
'''
try:
if not data.flags.c_contiguous:
warnings.warn(
"Use subset (sliced data) of np.ndarray is not recommended " +
"because it will generate extra copies and increase " +
"memory consumption")
data = np.array(data, copy=True, dtype=dtype)
else:
data = np.array(data, copy=False, dtype=dtype)
except AttributeError:
data = np.array(data, copy=False, dtype=dtype)
return data
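
A minimal sketch of the case _maybe_np_slice guards against, assuming numpy semantics: a column slice of a C-order array is not C-contiguous, so a copy is taken after the warning.

import numpy as np

mat = np.arange(12, dtype=np.float32).reshape(3, 4)
col = mat[:, 1]
assert not col.flags.c_contiguous
copied = np.array(col, copy=True, dtype=np.float32)  # what the helper does
assert copied.flags.c_contiguous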
class DMatrix(object):
"""Data Matrix used in XGBoost.
@@ -415,10 +425,10 @@ class DMatrix(object):
.. note:: For ranking task, weights are per-group.
In ranking task, one weight is assigned to each group (not each data
point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign
weights to individual data points.
In ranking task, one weight is assigned to each group (not each
data point). This is because we only care about the relative
ordering of data points within each group, so it doesn't make
sense to assign weights to individual data points.
silent : boolean, optional
Whether to print messages during construction
@@ -429,6 +439,7 @@ class DMatrix(object):
nthread : integer, optional
Number of threads to use for loading data from numpy array. If -1,
uses maximum threads available on the system.
"""
# force into void_p; Mac needs to pass things in as void_p
if data is None:
@@ -447,11 +458,6 @@
data, feature_names, feature_types
)
label = _maybe_pandas_label(label)
label = _maybe_dt_array(label)
weight = _maybe_dt_array(weight)
base_margin = _maybe_dt_array(base_margin)
if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
@@ -491,41 +497,47 @@ class DMatrix(object):
def _init_from_csr(self, csr):
"""Initialize data from a CSR matrix."""
if len(csr.indices) != len(csr.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
raise ValueError('length mismatch: {} vs {}'.format(
len(csr.indices), len(csr.data)))
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
c_array(ctypes.c_uint, csr.indices),
c_array(ctypes.c_float, csr.data),
ctypes.c_size_t(len(csr.indptr)),
ctypes.c_size_t(len(csr.data)),
ctypes.c_size_t(csr.shape[1]),
ctypes.byref(handle)))
_check_call(_LIB.XGDMatrixCreateFromCSREx(
c_array(ctypes.c_size_t, csr.indptr),
c_array(ctypes.c_uint, csr.indices),
c_array(ctypes.c_float, csr.data),
ctypes.c_size_t(len(csr.indptr)),
ctypes.c_size_t(len(csr.data)),
ctypes.c_size_t(csr.shape[1]),
ctypes.byref(handle)))
self.handle = handle
def _init_from_csc(self, csc):
"""Initialize data from a CSC matrix."""
if len(csc.indices) != len(csc.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
raise ValueError('length mismatch: {} vs {}'.format(
len(csc.indices), len(csc.data)))
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
c_array(ctypes.c_uint, csc.indices),
c_array(ctypes.c_float, csc.data),
ctypes.c_size_t(len(csc.indptr)),
ctypes.c_size_t(len(csc.data)),
ctypes.c_size_t(csc.shape[0]),
ctypes.byref(handle)))
_check_call(_LIB.XGDMatrixCreateFromCSCEx(
c_array(ctypes.c_size_t, csc.indptr),
c_array(ctypes.c_uint, csc.indices),
c_array(ctypes.c_float, csc.data),
ctypes.c_size_t(len(csc.indptr)),
ctypes.c_size_t(len(csc.data)),
ctypes.c_size_t(csc.shape[0]),
ctypes.byref(handle)))
self.handle = handle
def _init_from_npy2d(self, mat, missing, nthread):
"""Initialize data from a 2-D numpy matrix.
If ``mat`` does not have ``order='C'`` (aka row-major) or is not contiguous,
a temporary copy will be made.
If ``mat`` does not have ``order='C'`` (aka row-major) or is
not contiguous, a temporary copy will be made.
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will be made.
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
be made.
So there could be as many as two temporary data copies; be mindful of
input layout and type if memory use is a concern.
So there could be as many as two temporary data copies; be mindful of input layout
and type if memory use is a concern.
"""
if len(mat.shape) != 2:
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
@@ -536,21 +548,14 @@ class DMatrix(object):
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan
if nthread is None:
_check_call(_LIB.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing),
ctypes.byref(handle)))
else:
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing),
ctypes.byref(handle),
nthread))
nthread = nthread if nthread is not None else 1
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
c_bst_ulong(mat.shape[0]),
c_bst_ulong(mat.shape[1]),
ctypes.c_float(missing),
ctypes.byref(handle),
c_bst_ulong(nthread)))
self.handle = handle
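
Per the docstring above, up to two temporary copies can be made; a hedged sketch of avoiding both by converting the input up front (illustrative):

import numpy as np
import xgboost as xgb

raw = np.random.randn(100, 10)                     # float64 would be copied
mat = np.ascontiguousarray(raw, dtype=np.float32)  # one explicit conversion
dm = xgb.DMatrix(mat)                              # no further temporary copies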
def _init_from_dt(self, data, nthread):
@@ -572,7 +577,8 @@ class DMatrix(object):
# always return stypes for dt ingestion
feature_type_strings = (ctypes.c_char_p * data.ncols)()
for icol in range(data.ncols):
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
feature_type_strings[icol] = ctypes.c_char_p(
data.stypes[icol].name.encode('utf-8'))
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromDT(
@@ -598,7 +604,8 @@ class DMatrix(object):
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
interfaces_str,
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
ctypes.c_float(missing), ctypes.c_int(nthread),
ctypes.byref(handle)))
self.handle = handle
def _init_from_array_interface(self, data, missing, nthread):
@@ -614,7 +621,8 @@ class DMatrix(object):
_check_call(
_LIB.XGDMatrixCreateFromArrayInterface(
interface_str,
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
ctypes.c_float(missing), ctypes.c_int(nthread),
ctypes.byref(handle)))
self.handle = handle
def __del__(self):
@@ -675,6 +683,7 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
data, _, _ = _convert_dataframes(data, None, None, field, 'float')
if isinstance(data, np.ndarray):
self.set_float_info_npy2d(field, data)
return
@@ -684,20 +693,6 @@ class DMatrix(object):
c_data,
c_bst_ulong(len(data))))
def set_interface_info(self, field, data):
"""Set info type property into DMatrix."""
# If we are passed a dataframe, extract the series
if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
if len(data.columns) != 1:
raise ValueError('Expecting meta-info to contain a single column')
data = data[data.columns[0]]
interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
c_str(field),
interface))
def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
for numpy 2d array input
@@ -710,16 +705,7 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
try:
if not data.flags.c_contiguous:
warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " +
"because it will generate extra copies and increase " +
"memory consumption")
data = np.array(data, copy=True, dtype=np.float32)
else:
data = np.array(data, copy=False, dtype=np.float32)
except AttributeError:
data = np.array(data, copy=False, dtype=np.float32)
data = _maybe_np_slice(data, np.float32)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
@@ -737,21 +723,29 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
try:
if not data.flags.c_contiguous:
warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " +
"because it will generate extra copies and increase " +
"memory consumption")
data = np.array(data, copy=True, dtype=ctypes.c_uint)
else:
data = np.array(data, copy=False, dtype=ctypes.c_uint)
except AttributeError:
data = np.array(data, copy=False, dtype=ctypes.c_uint)
data = _maybe_np_slice(data, np.uint32)
data, _, _ = _convert_dataframes(data, None, None, field, 'uint32')
data = np.array(data, copy=False, dtype=ctypes.c_uint)
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
c_array(ctypes.c_uint, data),
c_bst_ulong(len(data))))
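
With set_uint_info now routing through _convert_dataframes using the 'uint32' meta type, unsigned meta info such as group sizes can also come from a single-column frame; a hedged sketch:

import numpy as np
import pandas as pd
import xgboost as xgb

dm = xgb.DMatrix(np.random.randn(10, 4), label=np.zeros(10))
dm.set_uint_info('group', pd.DataFrame([4, 3, 3]))  # sizes sum to num_row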
def set_interface_info(self, field, data):
"""Set info type property into DMatrix."""
# If we are passed a dataframe, extract the series
if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
if len(data.columns) != 1:
raise ValueError(
'Expecting meta-info to contain a single column')
data = data[data.columns[0]]
interface = bytes(json.dumps([data.__cuda_array_interface__],
indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
c_str(field),
interface))
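
A hedged usage sketch for the interface path (assumes a working cudf installation; cudf objects expose __cuda_array_interface__, so the label below is set via set_interface_info):

import numpy as np
import cudf
import xgboost as xgb

X = cudf.DataFrame({'x0': np.arange(8, dtype=np.float32)})
y = cudf.Series(np.arange(8, dtype=np.float32))
dm = xgb.DMatrix(X, label=y)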
def save_binary(self, fname, silent=True):
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
by providing the path to :py:func:`xgboost.DMatrix` as input.
@@ -775,26 +769,13 @@ class DMatrix(object):
label: array like
The label information to be set into DMatrix
"""
if isinstance(label, np.ndarray):
self.set_label_npy2d(label)
elif _has_cuda_array_interface(label):
if _has_cuda_array_interface(label):
self.set_interface_info('label', label)
else:
self.set_float_info('label', label)
def set_label_npy2d(self, label):
"""Set label of dmatrix
Parameters
----------
label: array like
The label information to be set into DMatrix
from numpy 2D array
"""
self.set_float_info_npy2d('label', label)
def set_weight(self, weight):
""" Set weight of each instance.
"""Set weight of each instance.
Parameters
----------
@@ -803,49 +784,30 @@ class DMatrix(object):
.. note:: For ranking task, weights are per-group.
In ranking task, one weight is assigned to each group (not each data
point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign
weights to individual data points.
In ranking task, one weight is assigned to each group (not each
data point). This is because we only care about the relative
ordering of data points within each group, so it doesn't make
sense to assign weights to individual data points.
"""
if isinstance(weight, np.ndarray):
self.set_weight_npy2d(weight)
elif _has_cuda_array_interface(weight):
if _has_cuda_array_interface(weight):
self.set_interface_info('weight', weight)
else:
self.set_float_info('weight', weight)
def set_weight_npy2d(self, weight):
""" Set weight of each instance
for numpy 2D array
Parameters
----------
weight : array like
Weight for each data point in numpy 2D array
.. note:: For ranking task, weights are per-group.
In ranking task, one weight is assigned to each group (not each data
point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign
weights to individual data points.
"""
self.set_float_info_npy2d('weight', weight)
def set_base_margin(self, margin):
""" Set base margin of booster to start from.
"""Set base margin of booster to start from.
This can be used to specify a prediction value of
existing model to be base_margin
However, remember margin is needed, instead of transformed prediction
e.g. for logistic regression: need to put in value before logistic transformation
see also example/demo.py
This can be used to specify a prediction value of an existing model to be
the base_margin. However, remember that the margin is needed, not the
transformed prediction, e.g. for logistic regression: put in the value
before the logistic transformation. See also example/demo.py.
Parameters
----------
margin: array like
Prediction margin of each datapoint
"""
if _has_cuda_array_interface(margin):
self.set_interface_info('base_margin', margin)

View File

@@ -920,7 +920,10 @@ class LambdaRankObj : public ObjFunction {
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels_.Size());
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size())
<< "group structure not consistent with #rows";
<< "group structure not consistent with #rows" << ", "
<< "group ponter size: " << gptr.size() << ", "
<< "labels size: " << info.labels_.Size() << ", "
<< "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back());
#if defined(__CUDACC__)
// Check if we have a GPU assignment; else, revert back to CPU
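
The enriched message reports the quantities needed to debug an inconsistent group pointer; a consistent structure satisfies gptr.back() == labels size, as in this Python-side sketch:

import numpy as np
import xgboost as xgb

dm = xgb.DMatrix(np.random.randn(10, 4), label=np.random.randint(2, size=10))
dm.set_group([4, 3, 3])  # group sizes sum to the 10 rows
# internally group_ptr_ becomes [0, 4, 7, 10], so gptr.back() == labels size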

View File

@@ -1,51 +1,54 @@
# -*- coding: utf-8 -*-
import unittest
import pytest
import testing as tm
import xgboost as xgb
try:
import datatable as dt
import pandas as pd
except ImportError:
pass
pytestmark = pytest.mark.skipif(
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
class TestDataTable(unittest.TestCase):
def test_dt(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
labels = dt.Frame([1, 2])
dm = xgb.DMatrix(dtable, label=labels)
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'float', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
# overwrite feature_names
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
self.assertRaises(ValueError, xgb.DMatrix, dtable)
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
dtable = dt.Frame(df)
dm = xgb.DMatrix(dtable)
assert dm.feature_names == ['A=1', 'A=2']
assert dm.feature_types == ['int', 'int']
assert dm.num_row() == 3
assert dm.num_col() == 2
# -*- coding: utf-8 -*-
import unittest
import pytest
import numpy as np
import testing as tm
import xgboost as xgb
try:
import datatable as dt
import pandas as pd
except ImportError:
pass
pytestmark = pytest.mark.skipif(
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
class TestDataTable(unittest.TestCase):
def test_dt(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
labels = dt.Frame([1, 2])
dm = xgb.DMatrix(dtable, label=labels)
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'float', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
# overwrite feature_names
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
self.assertRaises(ValueError, xgb.DMatrix, dtable)
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
dtable = dt.Frame(df)
dm = xgb.DMatrix(dtable)
assert dm.feature_names == ['A=1', 'A=2']
assert dm.feature_types == ['int', 'int']
assert dm.num_row() == 3
assert dm.num_col() == 2

View File

@@ -29,6 +29,7 @@ class TestPandas(unittest.TestCase):
assert dm.feature_types == ['int', 'float', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
# overwrite feature_names and feature_types
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
@@ -51,6 +52,7 @@ class TestPandas(unittest.TestCase):
assert dm.feature_types == ['int', 'float', 'i']
assert dm.num_row() == 2
assert dm.num_col() == 3
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
df = pd.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
@@ -110,21 +112,38 @@ class TestPandas(unittest.TestCase):
def test_pandas_label(self):
# label must be a single column
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
None, None, 'label', 'float')
# label must be supported dtype
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
None, None, 'label', 'float')
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
result = xgb.core._maybe_pandas_label(df)
result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
'label', 'float')
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
dtype=float))
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
assert dm.num_row() == 3
assert dm.num_col() == 2
def test_pandas_weight(self):
kRows = 32
kCols = 8
X = np.random.randn(kRows, kCols)
y = np.random.randn(kRows)
w = np.random.randn(kRows).astype(np.float32)
w_pd = pd.DataFrame(w)
data = xgb.DMatrix(X, y, w_pd)
assert data.num_row() == kRows
assert data.num_col() == kCols
np.testing.assert_array_equal(data.get_weight(), w)
def test_cv_as_pandas(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,

View File

@@ -97,6 +97,7 @@ def test_ranking():
valid_data = xgb.DMatrix(x_valid, y_valid)
test_data = xgb.DMatrix(x_test)
train_data.set_group(train_group)
assert train_data.get_label().shape[0] == x_train.shape[0]
valid_data.set_group(valid_group)
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',