Update sklearn API to pass along n_jobs to DMatrix creation (#2658)
commit 0664298bb2 (parent 19a53814ce)
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -257,13 +257,15 @@ class XGBModel(XGBModelBase):
             metric measured on the validation set to stderr.
         """
         if sample_weight is not None:
-            trainDmatrix = DMatrix(X, label=y, weight=sample_weight, missing=self.missing)
+            trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
+                                   missing=self.missing, nthread=self.n_jobs)
         else:
-            trainDmatrix = DMatrix(X, label=y, missing=self.missing)
+            trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
 
         evals_result = {}
         if eval_set is not None:
-            evals = list(DMatrix(x[0], label=x[1], missing=self.missing) for x in eval_set)
+            evals = list(DMatrix(x[0], label=x[1], missing=self.missing,
+                                 nthread=self.n_jobs) for x in eval_set)
             evals = list(zip(evals, ["validation_{}".format(i) for i in
                              range(len(evals))]))
         else:
@@ -304,7 +306,7 @@ class XGBModel(XGBModelBase):
 
     def predict(self, data, output_margin=False, ntree_limit=0):
         # pylint: disable=missing-docstring,invalid-name
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
         return self.get_booster().predict(test_dmatrix,
                                           output_margin=output_margin,
                                           ntree_limit=ntree_limit)
@@ -327,7 +329,7 @@ class XGBModel(XGBModelBase):
             leaf x ends up in. Leaves are numbered within
             ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
         """
-        test_dmatrix = DMatrix(X, missing=self.missing)
+        test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
         return self.get_booster().predict(test_dmatrix,
                                           pred_leaf=True,
                                           ntree_limit=ntree_limit)
@@ -475,7 +477,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         if eval_set is not None:
             # TODO: use sample_weight if given?
             evals = list(
-                DMatrix(x[0], label=self._le.transform(x[1]), missing=self.missing)
+                DMatrix(x[0], label=self._le.transform(x[1]),
+                        missing=self.missing, nthread=self.n_jobs)
                 for x in eval_set
             )
             nevals = len(evals)
@@ -488,10 +491,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         if sample_weight is not None:
             train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
-                                    missing=self.missing)
+                                    missing=self.missing, nthread=self.n_jobs)
         else:
             train_dmatrix = DMatrix(X, label=training_labels,
-                                    missing=self.missing)
+                                    missing=self.missing, nthread=self.n_jobs)
 
         self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
                               evals=evals,
@@ -514,7 +517,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return self
 
     def predict(self, data, output_margin=False, ntree_limit=0):
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  output_margin=output_margin,
                                                  ntree_limit=ntree_limit)
@@ -526,7 +529,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return self._le.inverse_transform(column_indexes)
 
     def predict_proba(self, data, output_margin=False, ntree_limit=0):
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
         class_probs = self.get_booster().predict(test_dmatrix,
                                                  output_margin=output_margin,
                                                  ntree_limit=ntree_limit)
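For context, a minimal usage sketch of what this change means for callers of the sklearn wrapper, assuming the post-commit API (the dataset and parameter values below are illustrative, not part of the commit): n_jobs set on the estimator is now also forwarded as nthread to every DMatrix built inside fit, predict, predict_proba, and apply, so matrix construction uses the same thread count as training instead of the DMatrix default.

# Illustrative sketch (not from the commit): n_jobs on the estimator is
# passed through as nthread=self.n_jobs to each internally created DMatrix.
import numpy as np
from xgboost import XGBClassifier

X = np.random.rand(1000, 20)
y = np.random.randint(2, size=1000)

clf = XGBClassifier(n_estimators=10, n_jobs=4)
clf.fit(X, y)                  # train_dmatrix is now built with nthread=4
probs = clf.predict_proba(X)   # test_dmatrix is likewise built with nthread=4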