Rewrite Dask interface. (#4819)

This commit is contained in:
Jiaming Yuan
2019-09-25 01:30:14 -04:00
committed by GitHub
parent 562bb0ae31
commit b8433c455a
17 changed files with 1002 additions and 361 deletions

View File

@@ -11,9 +11,9 @@ import os
from .core import DMatrix, Booster
from .training import train, cv
from . import rabit # noqa
from . import dask # noqa
from . import tracker # noqa
from .tracker import RabitTracker # noqa
from . import dask
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
@@ -30,4 +30,4 @@ __all__ = ['DMatrix', 'Booster',
'RabitTracker',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
'XGBRFClassifier', 'XGBRFRegressor',
'plot_importance', 'plot_tree', 'to_graphviz']
'plot_importance', 'plot_tree', 'to_graphviz', 'dask']

View File

@@ -96,14 +96,17 @@ except ImportError:
# pandas
try:
from pandas import DataFrame
from pandas import DataFrame, Series
from pandas import MultiIndex
from pandas import concat as pandas_concat
PANDAS_INSTALLED = True
except ImportError:
MultiIndex = object
DataFrame = object
Series = object
pandas_concat = None
PANDAS_INSTALLED = False
# dt
@@ -169,16 +172,35 @@ except ImportError:
# dask
try:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import Series as DaskSeries
from dask.array import Array as DaskArray
import dask
from dask import delayed
from dask import dataframe as dd
from dask import array as da
from dask.distributed import Client, get_client
from dask.distributed import comm as distributed_comm
from dask.distributed import wait as distributed_wait
from distributed import get_worker as distributed_get_worker
DASK_INSTALLED = True
except ImportError:
DaskDataFrame = object
DaskSeries = object
DaskArray = object
dd = None
da = None
Client = None
delayed = None
get_client = None
distributed_comm = None
distributed_wait = None
distributed_get_worker = None
dask = None
DASK_INSTALLED = False
try:
import sparse
import scipy.sparse as scipy_sparse
SCIPY_INSTALLED = True
except ImportError:
sparse = False
scipy_sparse = False
SCIPY_INSTALLED = False

View File

@@ -106,6 +106,28 @@ def from_cstr_to_pystr(data, length):
return res
def _expect(expectations, got):
'''Translate input error into string.
Parameters
----------
expectations: sequence
a list of expected value.
got:
actual input
Returns
-------
msg: str'''
msg = 'Expecting '
for t in range(len(expectations) - 1):
msg += str(expectations[t])
msg += ' or '
msg += str(expectations[-1])
msg += '. Got ' + str(got)
return msg
def _log_callback(msg):
    """Redirect logs from native library into Python console.

    The message is converted with ``py_str`` and written with ``print``.
    """
    text = py_str(msg)
    print("{0:s}".format(text))
@@ -513,7 +535,8 @@ class DMatrix(object):
and type if memory use is a concern.
"""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
mat.shape)
# flatten the array by rows and ensure it is float32.
# we try to avoid data copies if possible (reshape returns a view when possible
# and we explicitly tell np.array to try and avoid copying)
@@ -1010,7 +1033,7 @@ class Booster(object):
"""
for d in cache:
if not isinstance(d, DMatrix):
raise TypeError('invalid cache item: {}'.format(type(d).__name__))
raise TypeError('invalid cache item: {}'.format(type(d).__name__), cache)
self._validate_features(d)
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
@@ -1353,6 +1376,10 @@ class Booster(object):
if pred_interactions:
option_mask |= 0x10
if not isinstance(data, DMatrix):
raise TypeError('Expecting data to be a DMatrix object, got: ',
type(data))
if validate_features:
self._validate_features(data)

View File

@@ -1,121 +1,609 @@
# pylint: disable=wrong-import-position,wrong-import-order,import-error
"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
import os
import math
import platform
import logging
from threading import Thread
from . import rabit
from .core import DMatrix
from .compat import (DaskDataFrame, DaskSeries, DaskArray,
distributed_get_worker)
from .tracker import RabitTracker
def _start_tracker(n_workers):
    """Start a Rabit tracker on the current dask worker.

    The tracker is bound to the host part of this worker's address and is
    joined from a daemon thread.  Returns the environment dict that workers
    need in order to connect to the tracker.
    """
    address = distributed_get_worker().address
    # Drop a leading scheme (the part before '://') when there is one.
    address = address.rsplit('://', 1)[-1]
    host, port_text = address.split(':')
    int(port_text)  # the parsed port is unused, but a non-numeric port still raises

    tracker = RabitTracker(hostIP=host, nslave=n_workers)
    worker_env = {'DMLC_NUM_WORKER': n_workers}
    worker_env.update(tracker.slave_envs())
    tracker.start(n_workers)

    joiner = Thread(target=tracker.join)
    joiner.daemon = True
    joiner.start()
    return worker_env
def get_local_data(data):
    """
    Unpacks a distributed data object to get the rows local to this worker

    Partitions are split into contiguous, equally sized groups (the size is
    rounded up) and the group indexed by the rabit rank is computed.

    :param data: A distributed dask data object
    :return: Local data partition e.g. numpy or pandas
    """
    is_array = isinstance(data, DaskArray)
    total = len(data.chunks[0]) if is_array else data.npartitions

    per_worker = int(math.ceil(total / rabit.get_world_size()))
    first = per_worker * rabit.get_rank()
    last = min(first + per_worker, total)

    local = data.blocks[first:last] if is_array else data.partitions[first:last]
    return local.compute()
def create_worker_dmatrix(*args, **kwargs):
    """
    Creates a DMatrix object local to a given worker.

    Arguments are forwarded to the standard DMatrix constructor; every
    positional or keyword argument that is a dask dataframe, series or array
    is first replaced by its local part via ``get_local_data``.

    All dask dataframe arguments must use the same partitioning.

    :param args: DMatrix constructor args.
    :return: DMatrix object containing data local to current dask worker
    """
    dask_types = (DaskDataFrame, DaskSeries, DaskArray)

    def _localize(value):
        # Only dask collections are unpacked; everything else passes through.
        return get_local_data(value) if isinstance(value, dask_types) else value

    local_args = [_localize(arg) for arg in args]
    local_kwargs = {key: _localize(value) for key, value in kwargs.items()}
    return DMatrix(*local_args, **local_kwargs)
def _run_with_rabit(rabit_args, func, *args):
    """Run ``func(*args)`` between ``rabit.init`` and ``rabit.finalize``.

    ``OMP_NUM_THREADS`` is set from the worker's ``ncores`` attribute, or
    from ``nthreads`` when ``ncores`` does not exist.
    """
    worker = distributed_get_worker()
    try:
        n_threads = worker.ncores
    except AttributeError:
        n_threads = worker.nthreads
    os.environ["OMP_NUM_THREADS"] = str(n_threads)

    try:
        rabit.init(rabit_args)
        return func(*args)
    finally:
        # Always finalize, even when init or func raised.
        rabit.finalize()
def run(client, func, *args):
    """Launch arbitrary function on dask workers. Workers are connected by rabit,
    allowing distributed training. The environment variable OMP_NUM_THREADS is
    defined on each worker according to dask - this means that calls to
    xgb.train() will use the threads allocated by dask by default, unless the
    user overrides the nthread parameter.

    Note: Windows platforms are not officially
    supported. Contributions are welcome here.

    :param client: Dask client representing the cluster
    :param func: Python function to be executed by each worker. Typically
                 contains xgboost training code.
    :param args: Arguments to be forwarded to func
    :return: Dict containing the function return value for each worker
    """
    if platform.system() == 'Windows':
        # The two literals are concatenated by the parser, so the separating
        # space has to be part of one of them.
        logging.warning('Windows is not officially supported for dask/xgboost '
                        'integration. Contributions welcome.')

    workers = list(client.scheduler_info()['workers'])
    # The tracker is started on the first listed worker only.
    tracker_worker = workers[0]
    env = client.run(_start_tracker, len(workers), workers=[tracker_worker])
    rabit_args = [('%s=%s' % item).encode()
                  for item in env[tracker_worker].items()]
    return client.run(_run_with_rabit, rabit_args, func, *args)
# pylint: disable=too-many-arguments, too-many-locals
"""Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
tutorial. Also xgboost/demo/dask for some examples.
There are two sets of APIs in this module. One is the functional API, including
the ``train`` and ``predict`` functions. The other is a stateful Scikit-Learn
wrapper inherited from the single-node Scikit-Learn interface.
The implementation is heavily influenced by dask_xgboost:
https://github.com/dask/dask-xgboost
"""
import platform
import logging
from collections import defaultdict
from threading import Thread
import numpy
from . import rabit
from .compat import DASK_INSTALLED
from .compat import distributed_get_worker, distributed_wait, distributed_comm
from .compat import da, dd, delayed, get_client
from .compat import sparse, scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .core import DMatrix, Booster, _expect
from .training import train as worker_train
from .tracker import RabitTracker
from .sklearn import XGBModel, XGBClassifierBase
# Current status is considered as initial support, many features are
# not properly supported yet.
#
# TODOs:
# - Callback.
# - Label encoding.
# - CV
# - Ranking
def _start_tracker(host, n_workers):
    """Start Rabit tracker on ``host`` for ``n_workers`` workers.

    The tracker is joined from a daemon thread.  Returns the environment
    dict (``DMLC_NUM_WORKER`` plus the tracker's ``slave_envs()``) that the
    workers use to connect.
    """
    tracker = RabitTracker(hostIP=host, nslave=n_workers)
    worker_env = {'DMLC_NUM_WORKER': n_workers}
    worker_env.update(tracker.slave_envs())
    tracker.start(n_workers)

    joiner = Thread(target=tracker.join, daemon=True)
    joiner.start()
    return worker_env
def _assert_dask_installed():
    '''Raise ImportError unless ``DASK_INSTALLED`` is true.'''
    if DASK_INSTALLED:
        return
    raise ImportError(
        'Dask needs to be installed in order to use this module')
class RabitContext:
    '''A context controlling rabit initialization and finalization.

    ``rabit.init`` is called with the stored arguments on entry and
    ``rabit.finalize`` on exit.  ``__exit__`` returns None, so an exception
    raised inside the ``with`` body is not suppressed.
    '''

    def __init__(self, args):
        # Arguments handed to ``rabit.init`` when the context is entered.
        self.args = args

    def __enter__(self):
        rabit.init(self.args)
        logging.debug('-------------- rabit say hello ------------------')

    def __exit__(self, *exc_info):
        rabit.finalize()
        logging.debug('--------------- rabit say bye ------------------')
def concat(value):
    '''Concatenate a sequence of partitions along the first axis.

    The type of the first element selects the implementation: numpy, scipy
    sparse (stacked as CSR), ``sparse``, pandas, and finally dask dataframe
    as the fallback.  To be replaced with dask builtin.
    '''
    head = value[0]
    if isinstance(head, numpy.ndarray):
        return numpy.concatenate(value, axis=0)
    if scipy_sparse and isinstance(head, scipy_sparse.spmatrix):
        return scipy_sparse.vstack(value, format='csr')
    if sparse and isinstance(head, sparse.SparseArray):
        return sparse.concatenate(value, axis=0)
    if PANDAS_INSTALLED and isinstance(head, (DataFrame, Series)):
        return pandas_concat(value, axis=0)
    return dd.multi.concat(list(value), axis=0)
def _xgb_get_client(client):
'''Simple wrapper around testing None.'''
ret = get_client() if client is None else client
return ret
class DaskDMatrix:
    # pylint: disable=missing-docstring, too-many-instance-attributes
    '''DMatrix holding on references to Dask DataFrame or Dask Array.

    Parameters
    ----------
    client: dask.distributed.Client
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.
    data : dask.array.Array/dask.dataframe.DataFrame
        data source of DMatrix.
    label: dask.array.Array/dask.dataframe.DataFrame
        label used for training.
    missing : float, optional
        Value in the input data (e.g. `numpy.ndarray`) which needs
        to be present as a missing value. If None, defaults to np.nan.
    weight : dask.array.Array/dask.dataframe.DataFrame
        Weight for each instance.
    feature_names : list, optional
        Set names for features.
    feature_types : list, optional
        Set types for features

    Raises
    ------
    TypeError
        If ``data`` or ``label`` is not one of the supported dask types.
    ValueError
        If ``data`` is not 2 dimensional.
    '''

    _feature_names = None  # for previous version's pickle
    _feature_types = None

    def __init__(self,
                 client,
                 data,
                 label=None,
                 missing=None,
                 weight=None,
                 feature_names=None,
                 feature_types=None):
        _assert_dask_installed()

        self._feature_names = feature_names
        self._feature_types = feature_types
        self._missing = missing

        # Validate the types before touching ``data.shape`` so that an
        # unsupported input is reported as a TypeError.
        if not isinstance(data, (dd.DataFrame, da.Array)):
            raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
        if not isinstance(label,
                          (dd.DataFrame, da.Array, dd.Series, type(None))):
            raise TypeError(
                _expect((dd.DataFrame, da.Array, dd.Series), type(label)))

        if len(data.shape) != 2:
            # The message has to be raised; merely building it lets
            # non-2D input through.
            raise ValueError(_expect(['2 dimensions input'], data.shape))
        self.n_rows = data.shape[0]
        self.n_cols = data.shape[1]

        # Filled by ``map_local_data``: worker address -> list of futures.
        self.worker_map = None
        self.has_label = label is not None
        self.has_weights = weight is not None

        client = _xgb_get_client(client)
        client.sync(self.map_local_data, client, data, label, weight)

    async def map_local_data(self, client, data, label=None, weights=None):
        '''Obtain references to local data.

        Persists the inputs, zips the partitions of data/label/weights into
        tuples, computes them, and records in ``self.worker_map`` one of the
        workers holding each computed partition.
        '''
        data = data.persist()
        if label is not None:
            label = label.persist()
        if weights is not None:
            weights = weights.persist()
        # Breaking data into partitions, a trick borrowed from dask_xgboost.

        # `to_delayed` downgrades high-level objects into numpy or pandas
        # equivalents.
        X_parts = data.to_delayed()
        if isinstance(X_parts, numpy.ndarray):
            # Exactly one column of blocks is accepted here.
            assert X_parts.shape[1] == 1
            X_parts = X_parts.flatten().tolist()

        if label is not None:
            y_parts = label.to_delayed()
            if isinstance(y_parts, numpy.ndarray):
                assert y_parts.ndim == 1 or y_parts.shape[1] == 1
                y_parts = y_parts.flatten().tolist()
        if weights is not None:
            w_parts = weights.to_delayed()
            if isinstance(w_parts, numpy.ndarray):
                assert w_parts.ndim == 1 or w_parts.shape[1] == 1
                w_parts = w_parts.flatten().tolist()

        parts = [X_parts]
        if label is not None:
            assert len(X_parts) == len(
                y_parts), 'Partitions between X and y are not consistent'
            parts.append(y_parts)
        if weights is not None:
            assert len(X_parts) == len(
                w_parts), 'Partitions between X and weight are not consistent.'
            parts.append(w_parts)
        # One delayed tuple (X[, y[, w]]) per partition.
        parts = list(map(delayed, zip(*parts)))

        parts = client.compute(parts)
        await distributed_wait(parts)  # async wait for parts to be computed

        for part in parts:
            assert part.status == 'finished'

        key_to_partition = {part.key: part for part in parts}
        who_has = await client.scheduler.who_has(
            keys=[part.key for part in parts])

        worker_map = defaultdict(list)
        for key, workers in who_has.items():
            # Only the first reported holder of each key is recorded.
            worker_map[next(iter(workers))].append(key_to_partition[key])

        self.worker_map = worker_map

    def get_worker_parts(self, worker):
        '''Get mapped parts of data in each worker.

        Returns a ``(data, labels, weights)`` tuple; ``labels`` and
        ``weights`` are None when the matrix was built without them.
        '''
        list_of_parts = self.worker_map[worker.address]
        assert list_of_parts, 'data in ' + worker.address + ' was moved.'
        assert isinstance(list_of_parts, list)

        # `get_worker_parts` is launched inside worker.  In dask side
        # this should be equal to `worker._get_client`.
        client = get_client()
        list_of_parts = client.gather(list_of_parts)

        if self.has_label:
            if self.has_weights:
                data, labels, weights = zip(*list_of_parts)
            else:
                data, labels = zip(*list_of_parts)
                weights = None
        else:
            data = [d[0] for d in list_of_parts]
            labels = None
            weights = None
        return data, labels, weights

    def get_worker_data(self, worker):
        '''Get data that local to worker.

        Parameters
        ----------
        worker: The worker used as key to data.

        Returns
        -------
        A DMatrix object.

        '''
        data, labels, weights = self.get_worker_parts(worker)

        data = concat(data)
        labels = concat(labels) if self.has_label else None
        weights = concat(weights) if self.has_weights else None

        dmatrix = DMatrix(data,
                          labels,
                          weight=weights,
                          missing=self._missing,
                          feature_names=self._feature_names,
                          feature_types=self._feature_types)
        return dmatrix

    @staticmethod
    def _combine_shapes(shapes):
        '''Combine the 2-D shapes of row-wise partitions into one shape.

        Rows are summed; the column count must agree between partitions and
        is carried over unchanged.  An empty input yields ``(0, 0)``.
        '''
        rows = 0
        cols = 0
        for shape in shapes:
            rows += shape[0]
            if cols not in (0, shape[1]):
                raise ValueError(
                    'Shape between partitions are not the same. '
                    'Got: {} and {}'.format(shape[1], cols))
            cols = shape[1]
        return (rows, cols)

    def get_worker_data_shape(self, worker):
        '''Get the shape of data X in each worker.'''
        data, _, _ = self.get_worker_parts(worker)
        # Partitions are stacked along rows (see ``get_worker_data``), so
        # the column count is shared rather than accumulated.
        return self._combine_shapes([d.shape for d in data])

    def num_row(self):
        return self.n_rows

    def num_col(self):
        return self.n_cols
def _get_rabit_args(worker_map, client):
    '''Get rabit context arguments from data distribution in DaskDMatrix.

    A tracker is started on the scheduler for ``len(worker_map)`` workers and
    its environment is returned as a list of encoded ``key=value`` entries.
    '''
    scheduler_host = distributed_comm.get_address_host(
        client.scheduler.address)
    tracker_env = client.run_on_scheduler(
        _start_tracker, scheduler_host.strip('/:'), len(worker_map))
    return [('%s=%s' % pair).encode() for pair in tracker_env.items()]
# train and predict methods are supposed to be "functional", which meets the
# dask paradigm. But as a side effect, the `evals_result` in single-node API
# is no longer supported since it mutates the input parameter, and it's not
# intuitive to sync the mutation result. Therefore, a dictionary containing
# evaluation history is instead returned.
def train(client, params, dtrain, *args, evals=(), **kwargs):
    '''Train XGBoost model.

    Parameters
    ----------
    client: dask.distributed.Client
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.
    params: dict
        Booster parameters; a copy is handed to each worker.
    dtrain: DaskDMatrix
        Training data.
    evals: sequence of (DaskDMatrix, str) pairs
        Evaluation matrices and their names.

    Other parameters are the same as `xgboost.train` except for `evals_result`,
    which is returned as part of function return value instead of argument.

    Returns
    -------
    results: dict
        A dictionary containing trained booster and evaluation history.
        `history` field is the same as `eval_result` from `xgboost.train`.

        .. code-block:: python

            {'booster': xgboost.Booster,
             'history': {'train': {'logloss': ['0.48253', '0.35953']},
                         'eval': {'logloss': ['0.480385', '0.357756']}}}

    Raises
    ------
    ValueError
        If ``evals_result`` is passed as a keyword argument.
    '''
    _assert_dask_installed()
    if platform.system() == 'Windows':
        msg = 'Windows is not officially supported for dask/xgboost,'
        msg += ' contribution are welcomed.'
        logging.warning(msg)

    if 'evals_result' in kwargs:
        # A single string, so the exception text is not rendered as a tuple.
        raise ValueError(
            'evals_result is not supported in dask interface. '
            'The evaluation history is returned as result of training.')

    client = _xgb_get_client(client)

    worker_map = dtrain.worker_map
    rabit_args = _get_rabit_args(worker_map, client)

    def dispatched_train(worker_id):
        '''Perform training on worker.'''
        logging.info('Training on %d', worker_id)
        worker = distributed_get_worker()
        local_dtrain = dtrain.get_worker_data(worker)

        local_evals = []
        if evals:
            for mat, name in evals:
                local_mat = mat.get_worker_data(worker)
                local_evals.append((local_mat, name))

        with RabitContext(rabit_args):
            local_history = {}
            local_param = params.copy()  # just to be consistent
            # ``params`` and ``dtrain`` go in positionally: combined with
            # ``*args`` the keyword form raises "multiple values for
            # argument" as soon as ``args`` is non-empty.
            bst = worker_train(local_param,
                               local_dtrain,
                               *args,
                               evals_result=local_history,
                               evals=local_evals,
                               **kwargs)
            ret = {'booster': bst, 'history': local_history}
            # Only the rank-0 worker reports its result.
            if rabit.get_rank() != 0:
                ret = None
            return ret

    futures = client.map(dispatched_train,
                         range(len(worker_map)),
                         workers=list(worker_map.keys()))
    results = client.gather(futures)
    return [ret for ret in results if ret is not None][0]
def predict(client, model, data, *args):
    '''Run prediction with a trained booster.

    Parameters
    ----------
    client: dask.distributed.Client
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.
    model: A Booster or a dictionary returned by `xgboost.dask.train`.
        The trained model.
    data: DaskDMatrix
        Input data used for prediction.
    args:
        Extra positional arguments forwarded to ``Booster.predict`` after the
        data matrix.

    Returns
    -------
    prediction: dask.array.Array

    Raises
    ------
    TypeError
        If ``model`` is neither a Booster nor a dict, or ``data`` is not a
        DaskDMatrix.
    '''
    _assert_dask_installed()
    if isinstance(model, Booster):
        booster = model
    elif isinstance(model, dict):
        booster = model['booster']
    else:
        raise TypeError(_expect([Booster, dict], type(model)))

    if not isinstance(data, DaskDMatrix):
        raise TypeError(_expect([DaskDMatrix], type(data)))

    worker_map = data.worker_map
    client = _xgb_get_client(client)

    rabit_args = _get_rabit_args(worker_map, client)
    worker_addresses = list(worker_map.keys())

    def dispatched_predict(worker_id):
        '''Perform prediction on each worker.'''
        logging.info('Predicting on %d', worker_id)
        worker = distributed_get_worker()
        local_x = data.get_worker_data(worker)

        with RabitContext(rabit_args):
            # Positional, so that extra ``args`` follow the data instead of
            # colliding with a ``data=`` keyword.
            local_predictions = booster.predict(local_x, *args)
        return local_predictions

    futures = client.map(dispatched_predict,
                         range(len(worker_map)),
                         workers=worker_addresses)

    def dispatched_get_shape(worker_id):
        '''Get shape of data in each worker.'''
        logging.info('Trying to get data shape on %d', worker_id)
        worker = distributed_get_worker()
        rows, cols = data.get_worker_data_shape(worker)
        return rows, cols

    # Constructing a dask array from list of numpy arrays
    # See https://docs.dask.org/en/latest/array-creation.html
    futures_shape = client.map(dispatched_get_shape,
                               range(len(worker_map)),
                               workers=worker_addresses)
    shapes = client.gather(futures_shape)

    # Each chunk is declared as a vector with one entry per local row; the
    # column count of the feature matrix does not describe the prediction.
    # NOTE(review): objectives with 2-D output (e.g. multi:softprob) would
    # need the real output width here - confirm before relying on it.
    arrays = [
        da.from_delayed(future, shape=(shape[0],), dtype=numpy.float32)
        for future, shape in zip(futures, shapes)
    ]
    predictions = da.concatenate(arrays, axis=0)
    return predictions
def _evaluation_matrices(client, validation_set, sample_weights):
'''
Parameters
----------
validation_set: list of tuples
Each tuple contains a validation dataset including input X and label y.
E.g.:
.. code-block:: python
[(X_0, y_0), (X_1, y_1), ... ]
sample_weights: list of arrays
The weight vector for validation data.
Returns
-------
evals: list of validation DMatrix
'''
evals = []
if validation_set is not None:
assert isinstance(validation_set, list)
for i, e in enumerate(validation_set):
w = (sample_weights[i]
if sample_weights is not None else None)
dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w)
evals.append((dmat, 'validation_{}'.format(i)))
else:
evals = None
return evals
class DaskScikitLearnBase(XGBModel):
    '''Base class for implementing scikit-learn interface with Dask'''

    # Client set explicitly through the ``client`` property, if any.
    _client = None

    # pylint: disable=arguments-differ
    def fit(self,
            X,
            y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None):
        '''Fit the regressor.

        Parameters
        ----------
        X : array_like
            Feature matrix
        y : array_like
            Labels
        sample_weights : array_like
            instance weights
        eval_set : list, optional
            A list of (X, y) tuple pairs to use as validation sets, for which
            metrics will be computed.
            Validation metrics will help us track the performance of the model.
        sample_weight_eval_set : list, optional
            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
            of group weights on the i-th validation set.

        Subclasses must override this method; the base implementation raises
        NotImplementedError.'''
        raise NotImplementedError

    def predict(self, data):  # pylint: disable=arguments-differ
        '''Predict with `data`.

        Parameters
        ----------
        data: data that can be used to construct a DaskDMatrix

        Returns
        -------
        prediction : dask.array.Array

        Subclasses must override this method; the base implementation raises
        NotImplementedError.'''
        raise NotImplementedError

    @property
    def client(self):
        '''The dask client used in this model.

        Resolved through ``_xgb_get_client`` on every access.'''
        return _xgb_get_client(self._client)

    @client.setter
    def client(self, clt):
        self._client = clt
class DaskXGBRegressor(DaskScikitLearnBase):
    # pylint: disable=missing-docstring
    # The docstring is the XGBModel docstring with its first two lines
    # replaced by a regression-specific header.
    __doc__ = (
        'Implementation of the scikit-learn API for XGBoost ' +
        'regression. \n\n' +
        '\n'.join(XGBModel.__doc__.split('\n')[2:]))

    def fit(self,
            X,
            y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None):
        _assert_dask_installed()
        train_matrix = DaskDMatrix(client=self.client,
                                   data=X, label=y, weight=sample_weights)
        xgb_params = self.get_xgb_params()
        eval_matrices = _evaluation_matrices(self.client,
                                             eval_set, sample_weight_eval_set)

        outcome = train(self.client, xgb_params, train_matrix,
                        num_boost_round=self.get_num_boosting_rounds(),
                        evals=eval_matrices)
        # Keep the trained booster and the evaluation history on the model.
        self._Booster = outcome['booster']
        # pylint: disable=attribute-defined-outside-init
        self.evals_result_ = outcome['history']
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
        _assert_dask_installed()
        # Here ``predict`` resolves to the module-level function.
        return predict(client=self.client,
                       model=self.get_booster(),
                       data=DaskDMatrix(client=self.client, data=data))
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
    # pylint: disable=missing-docstring
    _client = None
    # The docstring is the XGBModel docstring with its first two lines
    # replaced by a classification-specific header.
    __doc__ = (
        'Implementation of the scikit-learn API for XGBoost ' +
        'classification.\n\n' +
        '\n'.join(XGBModel.__doc__.split('\n')[2:]))

    def fit(self,
            X,
            y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None):
        _assert_dask_installed()
        train_matrix = DaskDMatrix(client=self.client,
                                   data=X, label=y, weight=sample_weights)
        xgb_params = self.get_xgb_params()

        # pylint: disable=attribute-defined-outside-init
        self.classes_ = da.unique(y).compute()
        self.n_classes_ = len(self.classes_)

        # More than two distinct labels switches to the multi-class objective.
        if self.n_classes_ > 2:
            xgb_params["objective"] = "multi:softprob"
            xgb_params['num_class'] = self.n_classes_
        else:
            xgb_params["objective"] = "binary:logistic"
            xgb_params.setdefault('num_class', self.n_classes_)

        eval_matrices = _evaluation_matrices(self.client,
                                             eval_set, sample_weight_eval_set)
        outcome = train(self.client, xgb_params, train_matrix,
                        num_boost_round=self.get_num_boosting_rounds(),
                        evals=eval_matrices)
        self._Booster = outcome['booster']
        # pylint: disable=attribute-defined-outside-init
        self.evals_result_ = outcome['history']
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
        _assert_dask_installed()
        # Here ``predict`` resolves to the module-level function.
        return predict(client=self.client,
                       model=self.get_booster(),
                       data=DaskDMatrix(client=self.client, data=data))

View File

@@ -1,8 +1,6 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import
import warnings
import json
import numpy as np
@@ -282,7 +280,8 @@ class XGBModel(XGBModelBase):
"object {} will be lost. ".format(type(self).__name__) +
"If you did not mean to export the model to " +
"a non-Python binding of XGBoost, consider " +
"using `pickle` or `joblib` to save your model.", Warning)
"using `pickle` or `joblib` to save your model.",
Warning)
self.get_booster().save_model(fname)
def load_model(self, fname):