Fix linear gpu input (#6255)

Jiaming Yuan 2020-10-19 12:02:36 +08:00 committed by GitHub
parent cdcdab98b8
commit 5037abeb86
3 changed files with 24 additions and 3 deletions


@@ -231,7 +231,8 @@ class GBLinear : public GradientBooster {
     // start collecting the prediction
     const int ngroup = model_.learner_model_param->num_output_group;
     preds.resize(p_fmat->Info().num_row_ * ngroup);
-    for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
+    for (const auto &page : p_fmat->GetBatches<SparsePage>()) {
+      auto const& batch = page.GetView();
       // output convention: nrow * k, where nrow is number of rows
       // k is number of group
       // parallel over local batch
@@ -241,7 +242,7 @@ class GBLinear : public GradientBooster {
       }
 #pragma omp parallel for schedule(static)
       for (omp_ulong i = 0; i < nsize; ++i) {
-        const size_t ridx = batch.base_rowid + i;
+        const size_t ridx = page.base_rowid + i;
         // loop over output groups
         for (int gid = 0; gid < ngroup; ++gid) {
           bst_float margin =
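
For context, the loop above is the batch-prediction path that gblinear hits when the DMatrix was built from device memory. A minimal sketch of that scenario, assuming a CUDA-enabled XGBoost build with cupy installed; the data shapes, parameters, and round count below are arbitrary illustrations, not taken from the commit:

import cupy
import xgboost as xgb

# Random regression data that lives on the GPU.
X = cupy.random.rand(100, 10)
y = cupy.random.rand(100)
dtrain = xgb.DMatrix(X, label=y)

params = {'booster': 'gblinear', 'updater': 'gpu_coord_descent'}
booster = xgb.train(params, dtrain, num_boost_round=10)

# Prediction walks the SparsePage batches through the loop patched above,
# now going through page.GetView() before rows are indexed.
print(booster.predict(dtrain)[:5])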


@@ -44,6 +44,5 @@ TEST(GBLinear, JsonIO) {
     ASSERT_EQ(weights.size(), 17);
   }
 }
 } // namespace gbm
 } // namespace xgboost


@@ -1,5 +1,7 @@
 import sys
 from hypothesis import strategies, given, settings, assume
+import pytest
+import numpy
 import xgboost as xgb
 sys.path.append("tests/python")
 import testing as tm
@@ -48,3 +50,22 @@ class TestGPULinear:
         param = dataset.set_params(param)
         result = train_result(param, dataset.get_dmat(), num_rounds)['train'][dataset.metric]
         assert tm.non_increasing([result[0], result[-1]])
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_gpu_coordinate_from_cupy(self):
+        # Training linear model is quite expensive, so we don't include it in
+        # test_from_cupy.py
+        import cupy
+        params = {'booster': 'gblinear', 'updater': 'gpu_coord_descent',
+                  'n_estimators': 100}
+        X, y = tm.get_boston()
+        cpu_model = xgb.XGBRegressor(**params)
+        cpu_model.fit(X, y)
+        cpu_predt = cpu_model.predict(X)
+
+        X = cupy.array(X)
+        y = cupy.array(y)
+        gpu_model = xgb.XGBRegressor(**params)
+        gpu_model.fit(X, y)
+        gpu_predt = gpu_model.predict(X)
+        cupy.testing.assert_allclose(cpu_predt, gpu_predt)
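
Note that cpu_model and gpu_model above are trained with the same gpu_coord_descent updater; only the container of the training data differs (host numpy arrays versus device cupy arrays), so the comparison isolates the input path rather than the algorithm. A sketch for running just this case programmatically, assuming the file lives under the GPU test directory (the path below is an assumption; adjust it to the repo layout):

# Run only the new test case through pytest's Python entry point.
import sys
import pytest

sys.exit(pytest.main([
    "tests/python-gpu/test_gpu_linear.py::TestGPULinear::test_gpu_coordinate_from_cupy",
    "-v",
]))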