From 443ff746e9723dcf571769b0d6ea28fbcb3e4a3f Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 4 Apr 2018 15:08:22 +1200 Subject: [PATCH] Fix logic in GPU predictor cache lookup (#3217) * Fix logic in GPU predictor cache lookup * Add sklearn test for GPU prediction --- src/predictor/gpu_predictor.cu | 2 +- tests/python-gpu/test_gpu_prediction.py | 28 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index ff8ecfbdf..07469caeb 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -267,7 +267,7 @@ class GPUPredictor : public xgboost::Predictor { std::shared_ptr device_matrix; // Matrix is not in host cache, create a temporary matrix - if (this->cache_.find(dmat) != this->cache_.end()) { + if (this->cache_.find(dmat) == this->cache_.end()) { device_matrix = std::shared_ptr( new DeviceMatrix(dmat, param.gpu_id, param.silent)); } else { diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index 0b8d5b0ef..07e86d8de 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -72,3 +72,31 @@ class TestGPUPredict(unittest.TestCase): assert np.allclose(predict0, predict1) assert np.allclose(predict0, cpu_predict) + + def test_sklearn(self): + m, n = 15000, 14 + tr_size = 2500 + X = np.random.rand(m, n) + y = 200 * np.matmul(X, np.arange(-3, -3 + n)) + X_train, y_train = X[:tr_size, :], y[:tr_size] + X_test, y_test = X[tr_size:, :], y[tr_size:] + + # First with cpu_predictor + params = {'tree_method': 'gpu_hist', + 'predictor': 'cpu_predictor', + 'n_jobs': -1, + 'seed': 123 + } + m = xgb.XGBRegressor(**params).fit(X_train, y_train) + cpu_train_score = m.score(X_train, y_train) + cpu_test_score = m.score(X_test, y_test) + + # Now with gpu_predictor + params['predictor'] = 'gpu_predictor' + + m = xgb.XGBRegressor(**params).fit(X_train, y_train) + gpu_train_score = m.score(X_train, y_train) + gpu_test_score = m.score(X_test, y_test) + + assert np.allclose(cpu_train_score, gpu_train_score) + assert np.allclose(cpu_test_score, gpu_test_score)