Fix typo in dask interface. (#6240)
This commit is contained in:
@@ -40,10 +40,13 @@ kCols = 10
|
||||
kWorkers = 5
|
||||
|
||||
|
||||
def generate_array():
|
||||
def generate_array(with_weights=False):
|
||||
partition_size = 20
|
||||
X = da.random.random((kRows, kCols), partition_size)
|
||||
y = da.random.random(kRows, partition_size)
|
||||
if with_weights:
|
||||
w = da.random.random(kRows, partition_size)
|
||||
return X, y, w
|
||||
return X, y
|
||||
|
||||
|
||||
@@ -252,11 +255,11 @@ def test_dask_missing_value_cls():
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
regressor.client = client
|
||||
regressor.fit(X, y, eval_set=[(X, y)])
|
||||
regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
|
||||
prediction = regressor.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
@@ -274,12 +277,12 @@ def test_dask_regressor():
|
||||
def test_dask_classifier():
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
y = (y * 10).astype(np.int32)
|
||||
classifier = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2, eval_metric='merror')
|
||||
classifier.client = client
|
||||
classifier.fit(X, y, eval_set=[(X, y)])
|
||||
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
|
||||
prediction = classifier.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
|
||||
Reference in New Issue
Block a user