Fix metainfo from DataFrame. (#5216)
* Fix metainfo from DataFrame. * Unify helper functions for data and meta.
This commit is contained in:
parent
5d4c24a1fc
commit
1891cc766d
@ -231,10 +231,11 @@ def c_array(ctype, values):
|
|||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|
||||||
|
|
||||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
|
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64':
|
||||||
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
|
'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int',
|
||||||
'float16': 'float', 'float32': 'float', 'float64': 'float',
|
'uint64': 'int', 'float16': 'float', 'float32': 'float',
|
||||||
'bool': 'i'}
|
'float64': 'float', 'bool': 'i'}
|
||||||
|
|
||||||
|
|
||||||
# Either object has cuda array interface or contains columns with interfaces
|
# Either object has cuda array interface or contains columns with interfaces
|
||||||
def _has_cuda_array_interface(data):
|
def _has_cuda_array_interface(data):
|
||||||
@ -242,7 +243,8 @@ def _has_cuda_array_interface(data):
|
|||||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
|
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
|
||||||
|
|
||||||
|
|
||||||
def _maybe_pandas_data(data, feature_names, feature_types):
|
def _maybe_pandas_data(data, feature_names, feature_types,
|
||||||
|
meta=None, meta_type=None):
|
||||||
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
||||||
|
|
||||||
if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
|
if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
|
||||||
@ -250,51 +252,41 @@ def _maybe_pandas_data(data, feature_names, feature_types):
|
|||||||
|
|
||||||
data_dtypes = data.dtypes
|
data_dtypes = data.dtypes
|
||||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
|
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
|
||||||
bad_fields = [str(data.columns[i]) for i, dtype in
|
bad_fields = [
|
||||||
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
|
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
|
||||||
|
if dtype.name not in PANDAS_DTYPE_MAPPER
|
||||||
|
]
|
||||||
|
|
||||||
msg = """DataFrame.dtypes for data must be int, float or bool.
|
msg = """DataFrame.dtypes for data must be int, float or bool.
|
||||||
Did not expect the data types in fields """
|
Did not expect the data types in fields """
|
||||||
raise ValueError(msg + ', '.join(bad_fields))
|
raise ValueError(msg + ', '.join(bad_fields))
|
||||||
|
|
||||||
if feature_names is None:
|
if feature_names is None and meta is None:
|
||||||
if isinstance(data.columns, MultiIndex):
|
if isinstance(data.columns, MultiIndex):
|
||||||
feature_names = [
|
feature_names = [
|
||||||
' '.join([str(x) for x in i])
|
' '.join([str(x) for x in i]) for i in data.columns
|
||||||
for i in data.columns
|
|
||||||
]
|
]
|
||||||
elif isinstance(data.columns, Int64Index):
|
elif isinstance(data.columns, Int64Index):
|
||||||
feature_names = list(map(str, data.columns))
|
feature_names = list(map(str, data.columns))
|
||||||
else:
|
else:
|
||||||
feature_names = data.columns.format()
|
feature_names = data.columns.format()
|
||||||
|
|
||||||
if feature_types is None:
|
if feature_types is None and meta is None:
|
||||||
feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]
|
feature_types = [
|
||||||
|
PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes
|
||||||
|
]
|
||||||
|
|
||||||
data = data.values.astype('float')
|
if meta and len(data.columns) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
'DataFrame for {meta} cannot have multiple columns'.format(
|
||||||
|
meta=meta))
|
||||||
|
|
||||||
|
dtype = meta_type if meta_type else 'float'
|
||||||
|
data = data.values.astype(dtype)
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
def _maybe_pandas_label(label):
|
|
||||||
"""Extract internal data from pd.DataFrame for DMatrix label."""
|
|
||||||
|
|
||||||
if PANDAS_INSTALLED and isinstance(label, DataFrame):
|
|
||||||
if len(label.columns) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
'DataFrame for label cannot have multiple columns')
|
|
||||||
|
|
||||||
label_dtypes = label.dtypes
|
|
||||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER
|
|
||||||
for dtype in label_dtypes):
|
|
||||||
raise ValueError(
|
|
||||||
'DataFrame.dtypes for label must be int, float or bool')
|
|
||||||
label = label.values.astype('float')
|
|
||||||
# pd.Series can be passed to xgb as it is
|
|
||||||
|
|
||||||
return label
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_cudf_dataframe(data, feature_names, feature_types):
|
def _maybe_cudf_dataframe(data, feature_names, feature_types):
|
||||||
"""Extract internal data from cudf.DataFrame for DMatrix data."""
|
"""Extract internal data from cudf.DataFrame for DMatrix data."""
|
||||||
if not (CUDF_INSTALLED and isinstance(data,
|
if not (CUDF_INSTALLED and isinstance(data,
|
||||||
@ -324,11 +316,21 @@ DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
|
|||||||
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
||||||
|
|
||||||
|
|
||||||
def _maybe_dt_data(data, feature_names, feature_types):
|
def _maybe_dt_data(data, feature_names, feature_types,
|
||||||
|
meta=None, meta_type=None):
|
||||||
"""Validate feature names and types if data table"""
|
"""Validate feature names and types if data table"""
|
||||||
if not isinstance(data, DataTable):
|
if not isinstance(data, DataTable):
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
if meta and data.shape[1] > 1:
|
||||||
|
raise ValueError(
|
||||||
|
'DataTable for label or weight cannot have multiple columns')
|
||||||
|
if meta:
|
||||||
|
# below requires new dt version
|
||||||
|
# extract first column
|
||||||
|
data = data.to_numpy()[:, 0].astype(meta_type)
|
||||||
|
return data, None, None
|
||||||
|
|
||||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||||
bad_fields = [data.names[i]
|
bad_fields = [data.names[i]
|
||||||
for i, type_name in enumerate(data_types_names)
|
for i, type_name in enumerate(data_types_names)
|
||||||
@ -338,40 +340,31 @@ def _maybe_dt_data(data, feature_names, feature_types):
|
|||||||
Did not expect the data types in fields """
|
Did not expect the data types in fields """
|
||||||
raise ValueError(msg + ', '.join(bad_fields))
|
raise ValueError(msg + ', '.join(bad_fields))
|
||||||
|
|
||||||
if feature_names is None:
|
if feature_names is None and meta is None:
|
||||||
feature_names = data.names
|
feature_names = data.names
|
||||||
|
|
||||||
# always return stypes for dt ingestion
|
# always return stypes for dt ingestion
|
||||||
if feature_types is not None:
|
if feature_types is not None:
|
||||||
raise ValueError('DataTable has own feature types, cannot pass them in')
|
raise ValueError(
|
||||||
|
'DataTable has own feature types, cannot pass them in.')
|
||||||
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
|
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
def _maybe_dt_array(array):
|
def _convert_dataframes(data, feature_names, feature_types,
|
||||||
"""Extract numpy array from single column data table"""
|
meta=None, meta_type=None):
|
||||||
if not isinstance(array, DataTable) or array is None:
|
|
||||||
return array
|
|
||||||
|
|
||||||
if array.shape[1] > 1:
|
|
||||||
raise ValueError('DataTable for label or weight cannot have multiple columns')
|
|
||||||
|
|
||||||
# below requires new dt version
|
|
||||||
# extract first column
|
|
||||||
array = array.to_numpy()[:, 0].astype('float')
|
|
||||||
|
|
||||||
return array
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_dataframes(data, feature_names, feature_types):
|
|
||||||
data, feature_names, feature_types = _maybe_pandas_data(data,
|
data, feature_names, feature_types = _maybe_pandas_data(data,
|
||||||
feature_names,
|
feature_names,
|
||||||
feature_types)
|
feature_types,
|
||||||
|
meta,
|
||||||
|
meta_type)
|
||||||
|
|
||||||
data, feature_names, feature_types = _maybe_dt_data(data,
|
data, feature_names, feature_types = _maybe_dt_data(data,
|
||||||
feature_names,
|
feature_names,
|
||||||
feature_types)
|
feature_types,
|
||||||
|
meta,
|
||||||
|
meta_type)
|
||||||
|
|
||||||
data, feature_names, feature_types = _maybe_cudf_dataframe(
|
data, feature_names, feature_types = _maybe_cudf_dataframe(
|
||||||
data, feature_names, feature_types)
|
data, feature_names, feature_types)
|
||||||
@ -379,6 +372,23 @@ def _convert_dataframes(data, feature_names, feature_types):
|
|||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_np_slice(data, dtype=np.float32):
|
||||||
|
'''Handle numpy slice. This can be removed if we use __array_interface__.
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
if not data.flags.c_contiguous:
|
||||||
|
warnings.warn(
|
||||||
|
"Use subset (sliced data) of np.ndarray is not recommended " +
|
||||||
|
"because it will generate extra copies and increase " +
|
||||||
|
"memory consumption")
|
||||||
|
data = np.array(data, copy=True, dtype=dtype)
|
||||||
|
else:
|
||||||
|
data = np.array(data, copy=False, dtype=dtype)
|
||||||
|
except AttributeError:
|
||||||
|
data = np.array(data, copy=False, dtype=dtype)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class DMatrix(object):
|
class DMatrix(object):
|
||||||
"""Data Matrix used in XGBoost.
|
"""Data Matrix used in XGBoost.
|
||||||
|
|
||||||
@ -415,10 +425,10 @@ class DMatrix(object):
|
|||||||
|
|
||||||
.. note:: For ranking task, weights are per-group.
|
.. note:: For ranking task, weights are per-group.
|
||||||
|
|
||||||
In ranking task, one weight is assigned to each group (not each data
|
In ranking task, one weight is assigned to each group (not each
|
||||||
point). This is because we only care about the relative ordering of
|
data point). This is because we only care about the relative
|
||||||
data points within each group, so it doesn't make sense to assign
|
ordering of data points within each group, so it doesn't make
|
||||||
weights to individual data points.
|
sense to assign weights to individual data points.
|
||||||
|
|
||||||
silent : boolean, optional
|
silent : boolean, optional
|
||||||
Whether print messages during construction
|
Whether print messages during construction
|
||||||
@ -429,6 +439,7 @@ class DMatrix(object):
|
|||||||
nthread : integer, optional
|
nthread : integer, optional
|
||||||
Number of threads to use for loading data from numpy array. If -1,
|
Number of threads to use for loading data from numpy array. If -1,
|
||||||
uses maximum threads available on the system.
|
uses maximum threads available on the system.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# force into void_p, mac need to pass things in as void_p
|
# force into void_p, mac need to pass things in as void_p
|
||||||
if data is None:
|
if data is None:
|
||||||
@ -447,11 +458,6 @@ class DMatrix(object):
|
|||||||
data, feature_names, feature_types
|
data, feature_names, feature_types
|
||||||
)
|
)
|
||||||
|
|
||||||
label = _maybe_pandas_label(label)
|
|
||||||
label = _maybe_dt_array(label)
|
|
||||||
weight = _maybe_dt_array(weight)
|
|
||||||
base_margin = _maybe_dt_array(base_margin)
|
|
||||||
|
|
||||||
if isinstance(data, (STRING_TYPES, os_PathLike)):
|
if isinstance(data, (STRING_TYPES, os_PathLike)):
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
|
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
|
||||||
@ -491,9 +497,11 @@ class DMatrix(object):
|
|||||||
def _init_from_csr(self, csr):
|
def _init_from_csr(self, csr):
|
||||||
"""Initialize data from a CSR matrix."""
|
"""Initialize data from a CSR matrix."""
|
||||||
if len(csr.indices) != len(csr.data):
|
if len(csr.indices) != len(csr.data):
|
||||||
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
|
raise ValueError('length mismatch: {} vs {}'.format(
|
||||||
|
len(csr.indices), len(csr.data)))
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
|
_check_call(_LIB.XGDMatrixCreateFromCSREx(
|
||||||
|
c_array(ctypes.c_size_t, csr.indptr),
|
||||||
c_array(ctypes.c_uint, csr.indices),
|
c_array(ctypes.c_uint, csr.indices),
|
||||||
c_array(ctypes.c_float, csr.data),
|
c_array(ctypes.c_float, csr.data),
|
||||||
ctypes.c_size_t(len(csr.indptr)),
|
ctypes.c_size_t(len(csr.indptr)),
|
||||||
@ -505,9 +513,11 @@ class DMatrix(object):
|
|||||||
def _init_from_csc(self, csc):
|
def _init_from_csc(self, csc):
|
||||||
"""Initialize data from a CSC matrix."""
|
"""Initialize data from a CSC matrix."""
|
||||||
if len(csc.indices) != len(csc.data):
|
if len(csc.indices) != len(csc.data):
|
||||||
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
|
raise ValueError('length mismatch: {} vs {}'.format(
|
||||||
|
len(csc.indices), len(csc.data)))
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
|
_check_call(_LIB.XGDMatrixCreateFromCSCEx(
|
||||||
|
c_array(ctypes.c_size_t, csc.indptr),
|
||||||
c_array(ctypes.c_uint, csc.indices),
|
c_array(ctypes.c_uint, csc.indices),
|
||||||
c_array(ctypes.c_float, csc.data),
|
c_array(ctypes.c_float, csc.data),
|
||||||
ctypes.c_size_t(len(csc.indptr)),
|
ctypes.c_size_t(len(csc.indptr)),
|
||||||
@ -519,13 +529,15 @@ class DMatrix(object):
|
|||||||
def _init_from_npy2d(self, mat, missing, nthread):
|
def _init_from_npy2d(self, mat, missing, nthread):
|
||||||
"""Initialize data from a 2-D numpy matrix.
|
"""Initialize data from a 2-D numpy matrix.
|
||||||
|
|
||||||
If ``mat`` does not have ``order='C'`` (aka row-major) or is not contiguous,
|
If ``mat`` does not have ``order='C'`` (aka row-major) or is
|
||||||
a temporary copy will be made.
|
not contiguous, a temporary copy will be made.
|
||||||
|
|
||||||
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will be made.
|
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
|
||||||
|
be made.
|
||||||
|
|
||||||
|
So there could be as many as two temporary data copies; be mindful of
|
||||||
|
input layout and type if memory use is a concern.
|
||||||
|
|
||||||
So there could be as many as two temporary data copies; be mindful of input layout
|
|
||||||
and type if memory use is a concern.
|
|
||||||
"""
|
"""
|
||||||
if len(mat.shape) != 2:
|
if len(mat.shape) != 2:
|
||||||
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
|
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
|
||||||
@ -536,21 +548,14 @@ class DMatrix(object):
|
|||||||
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
|
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
missing = missing if missing is not None else np.nan
|
missing = missing if missing is not None else np.nan
|
||||||
if nthread is None:
|
nthread = nthread if nthread is not None else 1
|
||||||
_check_call(_LIB.XGDMatrixCreateFromMat(
|
|
||||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
|
||||||
c_bst_ulong(mat.shape[0]),
|
|
||||||
c_bst_ulong(mat.shape[1]),
|
|
||||||
ctypes.c_float(missing),
|
|
||||||
ctypes.byref(handle)))
|
|
||||||
else:
|
|
||||||
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
||||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||||
c_bst_ulong(mat.shape[0]),
|
c_bst_ulong(mat.shape[0]),
|
||||||
c_bst_ulong(mat.shape[1]),
|
c_bst_ulong(mat.shape[1]),
|
||||||
ctypes.c_float(missing),
|
ctypes.c_float(missing),
|
||||||
ctypes.byref(handle),
|
ctypes.byref(handle),
|
||||||
nthread))
|
c_bst_ulong(nthread)))
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
|
||||||
def _init_from_dt(self, data, nthread):
|
def _init_from_dt(self, data, nthread):
|
||||||
@ -572,7 +577,8 @@ class DMatrix(object):
|
|||||||
# always return stypes for dt ingestion
|
# always return stypes for dt ingestion
|
||||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||||
for icol in range(data.ncols):
|
for icol in range(data.ncols):
|
||||||
feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8'))
|
feature_type_strings[icol] = ctypes.c_char_p(
|
||||||
|
data.stypes[icol].name.encode('utf-8'))
|
||||||
|
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
_check_call(_LIB.XGDMatrixCreateFromDT(
|
||||||
@ -598,7 +604,8 @@ class DMatrix(object):
|
|||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||||
interfaces_str,
|
interfaces_str,
|
||||||
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
|
ctypes.c_float(missing), ctypes.c_int(nthread),
|
||||||
|
ctypes.byref(handle)))
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
|
||||||
def _init_from_array_interface(self, data, missing, nthread):
|
def _init_from_array_interface(self, data, missing, nthread):
|
||||||
@ -614,7 +621,8 @@ class DMatrix(object):
|
|||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDMatrixCreateFromArrayInterface(
|
_LIB.XGDMatrixCreateFromArrayInterface(
|
||||||
interface_str,
|
interface_str,
|
||||||
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
|
ctypes.c_float(missing), ctypes.c_int(nthread),
|
||||||
|
ctypes.byref(handle)))
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
@ -675,6 +683,7 @@ class DMatrix(object):
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
|
data, _, _ = _convert_dataframes(data, None, None, field, 'float')
|
||||||
if isinstance(data, np.ndarray):
|
if isinstance(data, np.ndarray):
|
||||||
self.set_float_info_npy2d(field, data)
|
self.set_float_info_npy2d(field, data)
|
||||||
return
|
return
|
||||||
@ -684,20 +693,6 @@ class DMatrix(object):
|
|||||||
c_data,
|
c_data,
|
||||||
c_bst_ulong(len(data))))
|
c_bst_ulong(len(data))))
|
||||||
|
|
||||||
def set_interface_info(self, field, data):
|
|
||||||
"""Set info type property into DMatrix."""
|
|
||||||
|
|
||||||
# If we are passed a dataframe, extract the series
|
|
||||||
if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
|
|
||||||
if len(data.columns) != 1:
|
|
||||||
raise ValueError('Expecting meta-info to contain a single column')
|
|
||||||
data = data[data.columns[0]]
|
|
||||||
|
|
||||||
interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8')
|
|
||||||
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
|
|
||||||
c_str(field),
|
|
||||||
interface))
|
|
||||||
|
|
||||||
def set_float_info_npy2d(self, field, data):
|
def set_float_info_npy2d(self, field, data):
|
||||||
"""Set float type property into the DMatrix
|
"""Set float type property into the DMatrix
|
||||||
for numpy 2d array input
|
for numpy 2d array input
|
||||||
@ -710,16 +705,7 @@ class DMatrix(object):
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
try:
|
data = _maybe_np_slice(data, np.float32)
|
||||||
if not data.flags.c_contiguous:
|
|
||||||
warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " +
|
|
||||||
"because it will generate extra copies and increase " +
|
|
||||||
"memory consumption")
|
|
||||||
data = np.array(data, copy=True, dtype=np.float32)
|
|
||||||
else:
|
|
||||||
data = np.array(data, copy=False, dtype=np.float32)
|
|
||||||
except AttributeError:
|
|
||||||
data = np.array(data, copy=False, dtype=np.float32)
|
|
||||||
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||||
c_str(field),
|
c_str(field),
|
||||||
@ -737,21 +723,29 @@ class DMatrix(object):
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
try:
|
data = _maybe_np_slice(data, np.uint32)
|
||||||
if not data.flags.c_contiguous:
|
data, _, _ = _convert_dataframes(data, None, None, field, 'uint32')
|
||||||
warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " +
|
|
||||||
"because it will generate extra copies and increase " +
|
|
||||||
"memory consumption")
|
|
||||||
data = np.array(data, copy=True, dtype=ctypes.c_uint)
|
|
||||||
else:
|
|
||||||
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
|
||||||
except AttributeError:
|
|
||||||
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
||||||
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
||||||
c_str(field),
|
c_str(field),
|
||||||
c_array(ctypes.c_uint, data),
|
c_array(ctypes.c_uint, data),
|
||||||
c_bst_ulong(len(data))))
|
c_bst_ulong(len(data))))
|
||||||
|
|
||||||
|
def set_interface_info(self, field, data):
|
||||||
|
"""Set info type property into DMatrix."""
|
||||||
|
# If we are passed a dataframe, extract the series
|
||||||
|
if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
|
||||||
|
if len(data.columns) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
'Expecting meta-info to contain a single column')
|
||||||
|
data = data[data.columns[0]]
|
||||||
|
|
||||||
|
interface = bytes(json.dumps([data.__cuda_array_interface__],
|
||||||
|
indent=2), 'utf-8')
|
||||||
|
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
|
||||||
|
c_str(field),
|
||||||
|
interface))
|
||||||
|
|
||||||
def save_binary(self, fname, silent=True):
|
def save_binary(self, fname, silent=True):
|
||||||
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
||||||
by providing the path to :py:func:`xgboost.DMatrix` as input.
|
by providing the path to :py:func:`xgboost.DMatrix` as input.
|
||||||
@ -775,26 +769,13 @@ class DMatrix(object):
|
|||||||
label: array like
|
label: array like
|
||||||
The label information to be set into DMatrix
|
The label information to be set into DMatrix
|
||||||
"""
|
"""
|
||||||
if isinstance(label, np.ndarray):
|
if _has_cuda_array_interface(label):
|
||||||
self.set_label_npy2d(label)
|
|
||||||
elif _has_cuda_array_interface(label):
|
|
||||||
self.set_interface_info('label', label)
|
self.set_interface_info('label', label)
|
||||||
else:
|
else:
|
||||||
self.set_float_info('label', label)
|
self.set_float_info('label', label)
|
||||||
|
|
||||||
def set_label_npy2d(self, label):
|
|
||||||
"""Set label of dmatrix
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
label: array like
|
|
||||||
The label information to be set into DMatrix
|
|
||||||
from numpy 2D array
|
|
||||||
"""
|
|
||||||
self.set_float_info_npy2d('label', label)
|
|
||||||
|
|
||||||
def set_weight(self, weight):
|
def set_weight(self, weight):
|
||||||
""" Set weight of each instance.
|
"""Set weight of each instance.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -803,49 +784,30 @@ class DMatrix(object):
|
|||||||
|
|
||||||
.. note:: For ranking task, weights are per-group.
|
.. note:: For ranking task, weights are per-group.
|
||||||
|
|
||||||
In ranking task, one weight is assigned to each group (not each data
|
In ranking task, one weight is assigned to each group (not each
|
||||||
point). This is because we only care about the relative ordering of
|
data point). This is because we only care about the relative
|
||||||
data points within each group, so it doesn't make sense to assign
|
ordering of data points within each group, so it doesn't make
|
||||||
weights to individual data points.
|
sense to assign weights to individual data points.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(weight, np.ndarray):
|
if _has_cuda_array_interface(weight):
|
||||||
self.set_weight_npy2d(weight)
|
|
||||||
elif _has_cuda_array_interface(weight):
|
|
||||||
self.set_interface_info('weight', weight)
|
self.set_interface_info('weight', weight)
|
||||||
else:
|
else:
|
||||||
self.set_float_info('weight', weight)
|
self.set_float_info('weight', weight)
|
||||||
|
|
||||||
def set_weight_npy2d(self, weight):
|
|
||||||
""" Set weight of each instance
|
|
||||||
for numpy 2D array
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
weight : array like
|
|
||||||
Weight for each data point in numpy 2D array
|
|
||||||
|
|
||||||
.. note:: For ranking task, weights are per-group.
|
|
||||||
|
|
||||||
In ranking task, one weight is assigned to each group (not each data
|
|
||||||
point). This is because we only care about the relative ordering of
|
|
||||||
data points within each group, so it doesn't make sense to assign
|
|
||||||
weights to individual data points.
|
|
||||||
"""
|
|
||||||
self.set_float_info_npy2d('weight', weight)
|
|
||||||
|
|
||||||
def set_base_margin(self, margin):
|
def set_base_margin(self, margin):
|
||||||
""" Set base margin of booster to start from.
|
"""Set base margin of booster to start from.
|
||||||
|
|
||||||
This can be used to specify a prediction value of
|
This can be used to specify a prediction value of existing model to be
|
||||||
existing model to be base_margin
|
base_margin However, remember margin is needed, instead of transformed
|
||||||
However, remember margin is needed, instead of transformed prediction
|
prediction e.g. for logistic regression: need to put in value before
|
||||||
e.g. for logistic regression: need to put in value before logistic transformation
|
logistic transformation see also example/demo.py
|
||||||
see also example/demo.py
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
margin: array like
|
margin: array like
|
||||||
Prediction margin of each datapoint
|
Prediction margin of each datapoint
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if _has_cuda_array_interface(margin):
|
if _has_cuda_array_interface(margin):
|
||||||
self.set_interface_info('base_margin', margin)
|
self.set_interface_info('base_margin', margin)
|
||||||
|
|||||||
@ -920,7 +920,10 @@ class LambdaRankObj : public ObjFunction {
|
|||||||
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels_.Size());
|
||||||
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
|
||||||
CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size())
|
CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size())
|
||||||
<< "group structure not consistent with #rows";
|
<< "group structure not consistent with #rows" << ", "
|
||||||
|
<< "group ponter size: " << gptr.size() << ", "
|
||||||
|
<< "labels size: " << info.labels_.Size() << ", "
|
||||||
|
<< "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back());
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__)
|
||||||
// Check if we have a GPU assignment; else, revert back to CPU
|
// Check if we have a GPU assignment; else, revert back to CPU
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import unittest
|
import unittest
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import testing as tm
|
import testing as tm
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
@ -29,6 +30,8 @@ class TestDataTable(unittest.TestCase):
|
|||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 3
|
assert dm.num_col() == 3
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
|
||||||
|
|
||||||
# overwrite feature_names
|
# overwrite feature_names
|
||||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
||||||
feature_names=['x', 'y', 'z'])
|
feature_names=['x', 'y', 'z'])
|
||||||
|
|||||||
@ -29,6 +29,7 @@ class TestPandas(unittest.TestCase):
|
|||||||
assert dm.feature_types == ['int', 'float', 'i']
|
assert dm.feature_types == ['int', 'float', 'i']
|
||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 3
|
assert dm.num_col() == 3
|
||||||
|
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
|
||||||
|
|
||||||
# overwrite feature_names and feature_types
|
# overwrite feature_names and feature_types
|
||||||
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
|
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
|
||||||
@ -51,6 +52,7 @@ class TestPandas(unittest.TestCase):
|
|||||||
assert dm.feature_types == ['int', 'float', 'i']
|
assert dm.feature_types == ['int', 'float', 'i']
|
||||||
assert dm.num_row() == 2
|
assert dm.num_row() == 2
|
||||||
assert dm.num_col() == 3
|
assert dm.num_col() == 3
|
||||||
|
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
|
||||||
|
|
||||||
df = pd.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
|
df = pd.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
|
||||||
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
|
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
|
||||||
@ -110,21 +112,38 @@ class TestPandas(unittest.TestCase):
|
|||||||
def test_pandas_label(self):
|
def test_pandas_label(self):
|
||||||
# label must be a single column
|
# label must be a single column
|
||||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
|
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
|
||||||
|
None, None, 'label', 'float')
|
||||||
|
|
||||||
# label must be supported dtype
|
# label must be supported dtype
|
||||||
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
||||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
|
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
|
||||||
|
None, None, 'label', 'float')
|
||||||
|
|
||||||
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
||||||
result = xgb.core._maybe_pandas_label(df)
|
result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
|
||||||
|
'label', 'float')
|
||||||
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
||||||
dtype=float))
|
dtype=float))
|
||||||
|
|
||||||
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
||||||
assert dm.num_row() == 3
|
assert dm.num_row() == 3
|
||||||
assert dm.num_col() == 2
|
assert dm.num_col() == 2
|
||||||
|
|
||||||
|
def test_pandas_weight(self):
|
||||||
|
kRows = 32
|
||||||
|
kCols = 8
|
||||||
|
|
||||||
|
X = np.random.randn(kRows, kCols)
|
||||||
|
y = np.random.randn(kRows)
|
||||||
|
w = np.random.randn(kRows).astype(np.float32)
|
||||||
|
w_pd = pd.DataFrame(w)
|
||||||
|
data = xgb.DMatrix(X, y, w_pd)
|
||||||
|
|
||||||
|
assert data.num_row() == kRows
|
||||||
|
assert data.num_col() == kCols
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(data.get_weight(), w)
|
||||||
|
|
||||||
def test_cv_as_pandas(self):
|
def test_cv_as_pandas(self):
|
||||||
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||||
|
|||||||
@ -97,6 +97,7 @@ def test_ranking():
|
|||||||
valid_data = xgb.DMatrix(x_valid, y_valid)
|
valid_data = xgb.DMatrix(x_valid, y_valid)
|
||||||
test_data = xgb.DMatrix(x_test)
|
test_data = xgb.DMatrix(x_test)
|
||||||
train_data.set_group(train_group)
|
train_data.set_group(train_group)
|
||||||
|
assert train_data.get_label().shape[0] == x_train.shape[0]
|
||||||
valid_data.set_group(valid_group)
|
valid_data.set_group(valid_group)
|
||||||
|
|
||||||
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user