parent
3290a4f3ed
commit
ee8d1f5ed8
@ -302,7 +302,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd:
|
with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd:
|
||||||
description = fd.read()
|
description = fd.read()
|
||||||
with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION')) as fd:
|
with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION'), encoding="ascii") as fd:
|
||||||
version = fd.read().strip()
|
version = fd.read().strip()
|
||||||
|
|
||||||
setup(name='xgboost',
|
setup(name='xgboost',
|
||||||
|
|||||||
@ -22,7 +22,7 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
||||||
with open(VERSION_FILE) as f:
|
with open(VERSION_FILE, encoding="ascii") as f:
|
||||||
__version__ = f.read().strip()
|
__version__ = f.read().strip()
|
||||||
|
|
||||||
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter',
|
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter',
|
||||||
|
|||||||
@ -70,7 +70,7 @@ try:
|
|||||||
'''Label encoder with JSON serialization methods.'''
|
'''Label encoder with JSON serialization methods.'''
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
'''Returns a JSON compatible dictionary'''
|
'''Returns a JSON compatible dictionary'''
|
||||||
meta = dict()
|
meta = {}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if isinstance(v, np.ndarray):
|
if isinstance(v, np.ndarray):
|
||||||
meta[k] = v.tolist()
|
meta[k] = v.tolist()
|
||||||
@ -81,7 +81,7 @@ try:
|
|||||||
def from_json(self, doc):
|
def from_json(self, doc):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
'''Load the encoder back from a JSON compatible dict.'''
|
'''Load the encoder back from a JSON compatible dict.'''
|
||||||
meta = dict()
|
meta = {}
|
||||||
for k, v in doc.items():
|
for k, v in doc.items():
|
||||||
if k == 'classes_':
|
if k == 'classes_':
|
||||||
self.classes_ = np.array(v)
|
self.classes_ = np.array(v)
|
||||||
|
|||||||
@ -2197,7 +2197,8 @@ class Booster(object):
|
|||||||
"""
|
"""
|
||||||
if isinstance(fout, (STRING_TYPES, os.PathLike)):
|
if isinstance(fout, (STRING_TYPES, os.PathLike)):
|
||||||
fout = os.fspath(os.path.expanduser(fout))
|
fout = os.fspath(os.path.expanduser(fout))
|
||||||
fout = open(fout, 'w') # pylint: disable=consider-using-with
|
# pylint: disable=consider-using-with
|
||||||
|
fout = open(fout, 'w', encoding="utf-8")
|
||||||
need_close = True
|
need_close = True
|
||||||
else:
|
else:
|
||||||
need_close = False
|
need_close = False
|
||||||
|
|||||||
@ -538,7 +538,7 @@ class XGBModel(XGBModelBase):
|
|||||||
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder',
|
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder',
|
||||||
"enable_categorical"
|
"enable_categorical"
|
||||||
}
|
}
|
||||||
filtered = dict()
|
filtered = {}
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
if k not in wrapper_specific and not callable(v):
|
if k not in wrapper_specific and not callable(v):
|
||||||
filtered[k] = v
|
filtered[k] = v
|
||||||
@ -557,7 +557,7 @@ class XGBModel(XGBModelBase):
|
|||||||
return self._estimator_type # pylint: disable=no-member
|
return self._estimator_type # pylint: disable=no-member
|
||||||
|
|
||||||
def save_model(self, fname: Union[str, os.PathLike]) -> None:
|
def save_model(self, fname: Union[str, os.PathLike]) -> None:
|
||||||
meta = dict()
|
meta = {}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k == '_le':
|
if k == '_le':
|
||||||
meta['_le'] = self._le.to_json()
|
meta['_le'] = self._le.to_json()
|
||||||
@ -596,7 +596,7 @@ class XGBModel(XGBModelBase):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
meta = json.loads(meta_str)
|
meta = json.loads(meta_str)
|
||||||
states = dict()
|
states = {}
|
||||||
for k, v in meta.items():
|
for k, v in meta.items():
|
||||||
if k == '_le':
|
if k == '_le':
|
||||||
self._le = XGBoostLabelEncoder()
|
self._le = XGBoostLabelEncoder()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2020 by XGBoost Contributors
|
* Copyright 2020-2021 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef HISTOGRAM_CUH_
|
#ifndef HISTOGRAM_CUH_
|
||||||
#define HISTOGRAM_CUH_
|
#define HISTOGRAM_CUH_
|
||||||
@ -15,8 +15,9 @@ namespace tree {
|
|||||||
template <typename GradientSumT>
|
template <typename GradientSumT>
|
||||||
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
|
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename U>
|
||||||
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
|
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) {
|
||||||
|
static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision.");
|
||||||
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
|
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user