Add new sklearn model attribute `n_features_in_` for number of features (#5780)
This commit is contained in:
parent
d39da42e69
commit
ae18a094b0
@ -499,6 +499,8 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
[xgb.callback.reset_learning_rate(custom_rates)]
|
||||||
"""
|
"""
|
||||||
|
self.n_features_in_ = X.shape[1]
|
||||||
|
|
||||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
@ -812,7 +814,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
# different ways of reshaping
|
# different ways of reshaping
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Please reshape the input data X into 2-dimensional matrix.')
|
'Please reshape the input data X into 2-dimensional matrix.')
|
||||||
|
|
||||||
self._features_count = X.shape[1]
|
self._features_count = X.shape[1]
|
||||||
|
self.n_features_in_ = self._features_count
|
||||||
|
|
||||||
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
missing=self.missing, nthread=self.n_jobs)
|
missing=self.missing, nthread=self.n_jobs)
|
||||||
@ -1195,6 +1200,8 @@ class XGBRanker(XGBModel):
|
|||||||
ret.set_group(group)
|
ret.set_group(group)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
self.n_features_in_ = X.shape[1]
|
||||||
|
|
||||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
missing=self.missing, nthread=self.n_jobs)
|
missing=self.missing, nthread=self.n_jobs)
|
||||||
|
|||||||
@ -115,6 +115,51 @@ def test_ranking():
|
|||||||
np.testing.assert_almost_equal(pred, pred_orig)
|
np.testing.assert_almost_equal(pred, pred_orig)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stacking_regression():
    """Fit and score a StackingRegressor that has XGBRegressor as a base learner.

    This is a smoke test: the score is computed but not checked, so the test
    passes as long as fitting and scoring complete without raising.
    """
    from sklearn.datasets import load_diabetes
    from sklearn.ensemble import RandomForestRegressor, StackingRegressor
    from sklearn.linear_model import RidgeCV
    from sklearn.model_selection import train_test_split

    X, y = load_diabetes(return_X_y=True)

    # Base learners: the XGBoost regressor stacked next to a ridge model.
    base_learners = [
        ('gbm', xgb.sklearn.XGBRegressor(objective='reg:squarederror')),
        ('lr', RidgeCV()),
    ]
    meta_learner = RandomForestRegressor(n_estimators=10, random_state=42)
    stacked = StackingRegressor(
        estimators=base_learners,
        final_estimator=meta_learner,
    )

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    fitted = stacked.fit(X_train, y_train)
    fitted.score(X_test, y_test)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stacking_classification():
    """Fit and score a StackingClassifier that has XGBClassifier as a base learner.

    This is a smoke test: the score is computed but not checked, so the test
    passes as long as fitting and scoring complete without raising.
    """
    from sklearn.datasets import load_iris
    from sklearn.ensemble import StackingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import LinearSVC

    X, y = load_iris(return_X_y=True)

    # Base learners: the XGBoost classifier next to a scaled linear SVM.
    booster = xgb.sklearn.XGBClassifier()
    scaled_svc = make_pipeline(StandardScaler(), LinearSVC(random_state=42))
    base_learners = [
        ('gbm', booster),
        ('svr', scaled_svc),
    ]
    stacked = StackingClassifier(
        estimators=base_learners,
        final_estimator=LogisticRegression(),
    )

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    fitted = stacked.fit(X_train, y_train)
    fitted.score(X_test, y_test)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_feature_importances_weight():
|
def test_feature_importances_weight():
|
||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user