Update sklearn API to pass along n_jobs to DMatrix creation (#2658)

PSEUDOTENSOR / Jonathan McKinney 2017-08-30 23:24:59 -04:00 committed by Rory Mitchell
parent 19a53814ce
commit 0664298bb2


@@ -257,13 +257,15 @@ class XGBModel(XGBModelBase):
             metric measured on the validation set to stderr.
         """
         if sample_weight is not None:
-            trainDmatrix = DMatrix(X, label=y, weight=sample_weight, missing=self.missing)
+            trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
+                                   missing=self.missing, nthread=self.n_jobs)
         else:
-            trainDmatrix = DMatrix(X, label=y, missing=self.missing)
+            trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
 
         evals_result = {}
         if eval_set is not None:
-            evals = list(DMatrix(x[0], label=x[1], missing=self.missing) for x in eval_set)
+            evals = list(DMatrix(x[0], label=x[1], missing=self.missing,
+                                 nthread=self.n_jobs) for x in eval_set)
             evals = list(zip(evals, ["validation_{}".format(i) for i in
                                      range(len(evals))]))
         else:
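
For context (commentary, not part of the commit): the low-level DMatrix constructor already accepts an nthread argument, which is what fit() now forwards from the wrapper's n_jobs. A minimal sketch with made-up data:

    import numpy as np
    from xgboost import DMatrix

    X = np.random.rand(1000, 20)
    y = np.random.randint(2, size=1000)
    # Roughly what fit() builds internally after this change,
    # assuming n_jobs=4 was passed to the sklearn wrapper:
    dtrain = DMatrix(X, label=y, missing=np.nan, nthread=4)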
@@ -304,7 +306,7 @@ class XGBModel(XGBModelBase):
 
     def predict(self, data, output_margin=False, ntree_limit=0):
         # pylint: disable=missing-docstring,invalid-name
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
         return self.get_booster().predict(test_dmatrix,
                                           output_margin=output_margin,
                                           ntree_limit=ntree_limit)
@@ -327,7 +329,7 @@ class XGBModel(XGBModelBase):
             leaf x ends up in. Leaves are numbered within
             ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
         """
-        test_dmatrix = DMatrix(X, missing=self.missing)
+        test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
         return self.get_booster().predict(test_dmatrix,
                                           pred_leaf=True,
                                           ntree_limit=ntree_limit)
@@ -475,7 +477,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         if eval_set is not None:
             # TODO: use sample_weight if given?
             evals = list(
-                DMatrix(x[0], label=self._le.transform(x[1]), missing=self.missing)
+                DMatrix(x[0], label=self._le.transform(x[1]),
+                        missing=self.missing, nthread=self.n_jobs)
                 for x in eval_set
             )
             nevals = len(evals)
@@ -488,10 +491,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         if sample_weight is not None:
             train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
-                                    missing=self.missing)
+                                    missing=self.missing, nthread=self.n_jobs)
         else:
             train_dmatrix = DMatrix(X, label=training_labels,
-                                    missing=self.missing)
+                                    missing=self.missing, nthread=self.n_jobs)
 
         self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
                               evals=evals,
@@ -514,7 +517,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return self
 
     def predict(self, data, output_margin=False, ntree_limit=0):
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  output_margin=output_margin,
                                                  ntree_limit=ntree_limit)
@@ -526,7 +529,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return self._le.inverse_transform(column_indexes)
 
     def predict_proba(self, data, output_margin=False, ntree_limit=0):
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  output_margin=output_margin,
                                                  ntree_limit=ntree_limit)
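
End-to-end, the change means a single n_jobs setting now controls both training and the DMatrix conversions performed inside fit, predict, and predict_proba. A usage sketch (hypothetical data; assumes a build that includes this commit):

    import numpy as np
    from xgboost import XGBClassifier

    X = np.random.rand(1000, 20)
    y = np.random.randint(2, size=1000)

    # n_jobs is forwarded as nthread to every internal DMatrix,
    # so data conversion is parallelized, not just tree construction.
    clf = XGBClassifier(n_estimators=10, n_jobs=4)
    clf.fit(X, y)
    probs = clf.predict_proba(X)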