Fix histogram truncation. (#7181)

* Fix truncation.

* Lint.

* lint.
This commit is contained in:
Jiaming Yuan 2021-08-25 09:34:32 +08:00 committed by GitHub
parent 3290a4f3ed
commit ee8d1f5ed8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 13 additions and 11 deletions

View File

@ -302,7 +302,7 @@ if __name__ == '__main__':
with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd:
description = fd.read()
with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION')) as fd:
with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION'), encoding="ascii") as fd:
version = fd.read().strip()
setup(name='xgboost',

View File

@ -22,7 +22,7 @@ except ImportError:
pass
VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
with open(VERSION_FILE) as f:
with open(VERSION_FILE, encoding="ascii") as f:
__version__ = f.read().strip()
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter',

View File

@ -70,7 +70,7 @@ try:
'''Label encoder with JSON serialization methods.'''
def to_json(self):
'''Returns a JSON compatible dictionary'''
meta = dict()
meta = {}
for k, v in self.__dict__.items():
if isinstance(v, np.ndarray):
meta[k] = v.tolist()
@ -81,7 +81,7 @@ try:
def from_json(self, doc):
# pylint: disable=attribute-defined-outside-init
'''Load the encoder back from a JSON compatible dict.'''
meta = dict()
meta = {}
for k, v in doc.items():
if k == 'classes_':
self.classes_ = np.array(v)

View File

@ -2197,7 +2197,8 @@ class Booster(object):
"""
if isinstance(fout, (STRING_TYPES, os.PathLike)):
fout = os.fspath(os.path.expanduser(fout))
fout = open(fout, 'w') # pylint: disable=consider-using-with
# pylint: disable=consider-using-with
fout = open(fout, 'w', encoding="utf-8")
need_close = True
else:
need_close = False

View File

@ -538,7 +538,7 @@ class XGBModel(XGBModelBase):
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder',
"enable_categorical"
}
filtered = dict()
filtered = {}
for k, v in params.items():
if k not in wrapper_specific and not callable(v):
filtered[k] = v
@ -557,7 +557,7 @@ class XGBModel(XGBModelBase):
return self._estimator_type # pylint: disable=no-member
def save_model(self, fname: Union[str, os.PathLike]) -> None:
meta = dict()
meta = {}
for k, v in self.__dict__.items():
if k == '_le':
meta['_le'] = self._le.to_json()
@ -596,7 +596,7 @@ class XGBModel(XGBModelBase):
)
return
meta = json.loads(meta_str)
states = dict()
states = {}
for k, v in meta.items():
if k == '_le':
self._le = XGBoostLabelEncoder()

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2020 by XGBoost Contributors
* Copyright 2020-2021 by XGBoost Contributors
*/
#ifndef HISTOGRAM_CUH_
#define HISTOGRAM_CUH_
@ -15,8 +15,9 @@ namespace tree {
template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
template <typename T>
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
template <typename T, typename U>
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) {
static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision.");
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
}