[dask] dask cudf inplace prediction. (#5512)
* Add inplace prediction for dask-cudf.
* Remove Dockerfile.release, since it's not used anywhere.
* Use Conda exclusively in the cuDF and GPU containers.
* Improve cupy memory copying.
* Add skip marks to tests.
* Add an mgpu-cudf category to the CI to run all distributed tests.

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
parent ca4e05660e
commit 8b04736b81

Jenkinsfile (vendored): 18 lines changed
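For orientation, here is a minimal usage sketch of the feature this commit adds: inplace prediction straight from a dask-cudf collection. The cluster setup and API calls mirror the tests touched below; the dataset, partition count, and training parameters are made up for illustration.

import numpy as np
import cudf
import dask_cudf
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from xgboost import dask as dxgb

with LocalCUDACluster() as cluster:
    with Client(cluster) as client:
        # Small random dataset, partitioned across the GPU workers.
        df = cudf.DataFrame({'f0': np.random.randn(1000),
                             'f1': np.random.randn(1000),
                             'y': np.random.randn(1000)})
        ddf = dask_cudf.from_cudf(df, npartitions=4)
        X, y = ddf[['f0', 'f1']], ddf['y']

        dtrain = dxgb.DaskDMatrix(client, X, y)
        out = dxgb.train(client, {'tree_method': 'gpu_hist'},
                         dtrain, num_boost_round=10)

        # The new path: predict directly from the dask-cudf collection,
        # without building a DaskDMatrix first.
        predictions = dxgb.inplace_predict(client, out, X).compute()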
@@ -307,19 +307,25 @@ def TestPythonGPU(args) {
       sh """
       ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu
       """
+      if (args.cuda_version != '9.0') {
+        echo "Running tests with cuDF..."
+        sh """
+        ${dockerRun} cudf ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu-cudf
+        """
+      }
     } else {
       echo "Using a single GPU"
       sh """
       ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh gpu
       """
+      if (args.cuda_version != '9.0') {
+        echo "Running tests with cuDF..."
+        sh """
+        ${dockerRun} cudf ${docker_binary} ${docker_args} tests/ci_build/test_python.sh cudf
+        """
+      }
     }
-    // For CUDA 10.0 target, run cuDF tests too
-    if (args.cuda_version == '10.0') {
-      echo "Running tests with cuDF..."
-      sh """
-      ${dockerRun} cudf ${docker_binary} ${docker_args} tests/ci_build/test_python.sh cudf
-      """
-    }
     deleteDir()
   }
 }

@@ -209,15 +209,32 @@ def ctypes2numpy(cptr, length, dtype):

 def ctypes2cupy(cptr, length, dtype):
     """Convert a ctypes pointer array to a cupy array."""
-    import cupy  # pylint: disable=import-error
-    mem = cupy.zeros(length.value, dtype=dtype, order='C')
+    # pylint: disable=import-error
+    import cupy
+    from cupy.cuda.memory import MemoryPointer
+    from cupy.cuda.memory import UnownedMemory
+    CUPY_TO_CTYPES_MAPPING = {
+        cupy.float32: ctypes.c_float,
+        cupy.uint32: ctypes.c_uint
+    }
+    if dtype not in CUPY_TO_CTYPES_MAPPING.keys():
+        raise RuntimeError('Supported types: {}'.format(
+            CUPY_TO_CTYPES_MAPPING.keys()
+        ))
     addr = ctypes.cast(cptr, ctypes.c_void_p).value
     # pylint: disable=c-extension-no-member,no-member
-    cupy.cuda.runtime.memcpy(
-        mem.__cuda_array_interface__['data'][0], addr,
-        length.value * ctypes.sizeof(ctypes.c_float),
-        cupy.cuda.runtime.memcpyDeviceToDevice)
-    return mem
+    device = cupy.cuda.runtime.pointerGetAttributes(addr).device
+    # The owner field is just used to keep the memory alive with ref count. As
+    # unowned's life time is scoped within this function we don't need that.
+    unownd = UnownedMemory(
+        addr, length.value * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]),
+        owner=None)
+    memptr = MemoryPointer(unownd, 0)
+    # pylint: disable=unexpected-keyword-arg
+    mem = cupy.ndarray((length.value, ), dtype=dtype, memptr=memptr)
+    assert mem.device.id == device
+    arr = cupy.array(mem, copy=True)
+    return arr


 def ctypes2buffer(cptr, length):

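The rewritten ctypes2cupy above drops the extra zero-initialised allocation and instead wraps the raw device pointer with cupy's UnownedMemory/MemoryPointer before copying once. A standalone sketch of that wrapping pattern, using a pointer borrowed from an existing cupy array rather than one returned by the XGBoost C API:

import ctypes
import cupy
from cupy.cuda.memory import MemoryPointer, UnownedMemory

src = cupy.arange(8, dtype=cupy.float32)
addr = src.__cuda_array_interface__['data'][0]      # raw device pointer
nbytes = src.size * ctypes.sizeof(ctypes.c_float)

# Wrap the existing allocation without taking ownership of it.
unowned = UnownedMemory(addr, nbytes, owner=None)
view = cupy.ndarray((src.size,), dtype=cupy.float32,
                    memptr=MemoryPointer(unowned, 0))

# Copy once so the result no longer aliases the borrowed buffer.
result = cupy.array(view, copy=True)
assert bool((result == src).all())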
@@ -101,6 +101,11 @@ def concat(value):  # pylint: disable=too-many-return-statements
         return CUDF_concat(value, axis=0)
     if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
         import cupy  # pylint: disable=import-error
+        # pylint: disable=c-extension-no-member,no-member
+        d = cupy.cuda.runtime.getDevice()
+        for v in value:
+            d_v = v.device.id
+            assert d_v == d, 'Concatenating arrays on different devices.'
         return cupy.concatenate(value, axis=0)
     return dd.multi.concat(list(value), axis=0)

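As an illustration of the cupy branch added to concat() above, the device check and concatenation boil down to the following (the chunk values are made up):

import cupy

chunks = [cupy.arange(4, dtype=cupy.float32),
          cupy.ones(4, dtype=cupy.float32)]

# Every chunk must live on the worker's current device before concatenating.
current = cupy.cuda.runtime.getDevice()
for chunk in chunks:
    assert chunk.device.id == current, 'Concatenating arrays on different devices.'

merged = cupy.concatenate(chunks, axis=0)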
@@ -631,8 +636,6 @@ def inplace_predict(client, model, data,
         if is_df:
             if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
                 import cudf  # pylint: disable=import-error
-                # There's an error with cudf saying `concat_cudf` got an
-                # expected argument `ignore_index`. So this is not yet working.
                 prediction = cudf.DataFrame({'prediction': prediction},
                                             dtype=numpy.float32)
             else:

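The DataFrame branch above simply re-wraps a partition's prediction array as a single-column cudf DataFrame. A self-contained sketch of that wrapping (the column name 'prediction' comes from the diff; the input array is illustrative):

import cupy
import numpy
import cudf

# Stand-in for the per-partition prediction array.
prediction = cupy.random.rand(8, dtype=cupy.float32)
frame = cudf.DataFrame({'prediction': prediction}, dtype=numpy.float32)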
@@ -12,8 +12,8 @@ RUN \
     wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
     bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \
     # Python
-    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
-    bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/python
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python

 ENV PATH=/opt/python/bin:$PATH

@@ -10,20 +10,16 @@ RUN \
     apt-get update && \
     apt-get install -y wget unzip bzip2 libgomp1 build-essential && \
     # Python
-    wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
-    bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python

 ENV PATH=/opt/python/bin:$PATH

-# Create new Conda environment with cuDF and dask
+# Create new Conda environment with cuDF, Dask, and cuPy
 RUN \
-    conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \
-    cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda cupy
-
-# Install other Python packages
-RUN \
-    source activate cudf_test && \
-    pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz
+    conda create -n cudf_test -c rapidsai -c nvidia -c conda-forge -c defaults \
+    python=3.7 cudf cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
+    numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz

 ENV GOSU_VERSION 1.10

@@ -9,16 +9,16 @@ RUN \
     apt-get update && \
     apt-get install -y wget unzip bzip2 libgomp1 build-essential && \
     # Python
-    wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
-    bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python

 ENV PATH=/opt/python/bin:$PATH

 # Install Python packages
 RUN \
-    pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
-    pip install "dask[complete]" && \
-    conda install -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda dask-cuda
+    conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \
+    python=3.7 dask dask-cuda numpy pytest scipy scikit-learn pandas \
+    matplotlib wheel python-kubernetes urllib3 graphviz

 ENV GOSU_VERSION 1.10

@@ -17,8 +17,8 @@ RUN \
     $DEVTOOLSET_URL_ROOT/devtoolset-4-runtime-4.1-3.sc1.el6.x86_64.rpm \
    $DEVTOOLSET_URL_ROOT/devtoolset-4-libstdc++-devel-5.3.1-6.1.el6.x86_64.rpm && \
     # Python
-    wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
-    bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python && \
     # CMake
     wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
     bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr

@@ -8,8 +8,8 @@ RUN \
     yum -y update && \
     yum install -y devtoolset-6-gcc devtoolset-6-binutils devtoolset-6-gcc-c++ && \
     # Python
-    wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
-    bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python && \
     # CMake
     wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
     bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr && \

@@ -13,8 +13,8 @@ RUN \
     apt-get update && \
     apt-get install -y tar unzip wget openjdk-$JDK_VERSION-jdk libgomp1 && \
     # Python
-    wget https://repo.continuum.io/miniconda/Miniconda3-4.5.12-Linux-x86_64.sh && \
-    bash Miniconda3-4.5.12-Linux-x86_64.sh -b -p /opt/python && \
+    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+    bash Miniconda3.sh -b -p /opt/python && \
     /opt/python/bin/pip install awscli && \
     # Maven
     wget https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \

@@ -1,31 +0,0 @@
-FROM centos:6
-
-# Install all basic requirements
-RUN \
-    yum -y update && \
-    yum install -y graphviz tar unzip wget xz git && \
-    # Python
-    wget https://repo.continuum.io/miniconda/Miniconda2-4.3.27-Linux-x86_64.sh && \
-    bash Miniconda2-4.3.27-Linux-x86_64.sh -b -p /opt/python
-
-ENV PATH=/opt/python/bin:$PATH
-
-# Install Python packages
-RUN \
-    conda install numpy scipy pandas matplotlib pytest scikit-learn && \
-    pip install pytest wheel auditwheel graphviz
-
-ENV GOSU_VERSION 1.10
-
-# Install lightweight sudo (not bound to TTY)
-RUN set -ex; \
-    wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
-    chmod +x /usr/local/bin/gosu && \
-    gosu nobody true
-
-# Default entry-point to use if running locally
-# It will preserve attributes of created files
-COPY entrypoint.sh /scripts/
-
-WORKDIR /workspace
-ENTRYPOINT ["/scripts/entrypoint.sh"]

@@ -28,23 +28,33 @@ function install_xgboost {
 # Run specified test suite
 case "$suite" in
     gpu)
+        source activate gpu_test
         install_xgboost
-        pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu
+        pytest -v -s -rxXs --fulltrace -m "not mgpu" tests/python-gpu
         ;;

     mgpu)
+        source activate gpu_test
         install_xgboost
-        pytest -v -s --fulltrace -m "mgpu" tests/python-gpu
+        pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu

         cd tests/distributed
         ./runtests-gpu.sh
         cd -
-        pytest -v -s --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py
         ;;

     cudf)
+        source activate cudf_test
         install_xgboost
-        pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_cudf.py tests/python-gpu/test_from_cupy.py
+        pytest -v -s -rxXs --fulltrace -m "not mgpu" \
+            tests/python-gpu/test_from_cudf.py tests/python-gpu/test_from_cupy.py \
+            tests/python-gpu/test_gpu_prediction.py
         ;;
+
+    mgpu-cudf)
+        source activate cudf_test
+        install_xgboost
+        pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py
+        ;;

     cpu)

@@ -62,6 +62,7 @@ class TestGPUPredict(unittest.TestCase):

     # Test case for a bug where multiple batch predictions made on a
     # test set produce incorrect results
+    @pytest.mark.skipif(**tm.no_sklearn())
     def test_multi_predict(self):
         from sklearn.datasets import make_regression
         from sklearn.model_selection import train_test_split

@@ -89,6 +90,7 @@ class TestGPUPredict(unittest.TestCase):
         assert np.allclose(predict0, predict1)
         assert np.allclose(predict0, cpu_predict)

+    @pytest.mark.skipif(**tm.no_sklearn())
     def test_sklearn(self):
         m, n = 15000, 14
         tr_size = 2500

@@ -27,6 +27,7 @@ class TestDistributedGPU(unittest.TestCase):
     @pytest.mark.skipif(**tm.no_cudf())
     @pytest.mark.skipif(**tm.no_dask_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
     @pytest.mark.mgpu
     def test_dask_dataframe(self):
         with LocalCUDACluster() as cluster:
             with Client(cluster) as client:

@@ -51,18 +52,18 @@ class TestDistributedGPU(unittest.TestCase):
                 predictions = dxgb.predict(client, out, dtrain).compute()
                 assert isinstance(predictions, np.ndarray)

-                # There's an error with cudf saying `concat_cudf` got an
-                # expected argument `ignore_index`. So the test here is just
-                # place holder.
-
-                # series_predictions = dxgb.inplace_predict(client, out, X)
-                # assert isinstance(series_predictions, dd.Series)
+                series_predictions = dxgb.inplace_predict(client, out, X)
+                assert isinstance(series_predictions, dd.Series)
+                series_predictions = series_predictions.compute()

                 single_node = out['booster'].predict(
                     xgboost.DMatrix(X.compute()))

                 cupy.testing.assert_allclose(single_node, predictions)
+                cupy.testing.assert_allclose(single_node, series_predictions)

     @pytest.mark.skipif(**tm.no_cupy())
     @pytest.mark.mgpu
     def test_dask_array(self):
         with LocalCUDACluster() as cluster:
             with Client(cluster) as client:

@@ -82,8 +83,12 @@ class TestDistributedGPU(unittest.TestCase):
                 single_node = out['booster'].predict(
                     xgboost.DMatrix(X.compute()))
                 np.testing.assert_allclose(single_node, from_dmatrix)
+                device = cupy.cuda.runtime.getDevice()
+                assert device == inplace_predictions.device.id
+                single_node = cupy.array(single_node)
+                assert device == single_node.device.id
                 cupy.testing.assert_allclose(
-                    cupy.array(single_node),
+                    single_node,
                     inplace_predictions)

@@ -1,12 +1,12 @@
 from __future__ import print_function

 import sys
 import numpy as np
-from sklearn.datasets import make_regression

 import unittest
 import pytest

 import xgboost as xgb
+sys.path.append("tests/python")
+import testing as tm

 rng = np.random.RandomState(1994)

@@ -20,6 +20,7 @@ def non_increasing(L):


 def assert_constraint(constraint, tree_method):
+    from sklearn.datasets import make_regression
     n = 1000
     X, y = make_regression(n, random_state=rng, n_features=1, n_informative=1)
     dtrain = xgb.DMatrix(X, y)

@@ -35,12 +36,13 @@ def assert_constraint(constraint, tree_method):
     assert non_increasing(pred)


 @pytest.mark.gpu
 class TestMonotonicConstraints(unittest.TestCase):
+    @pytest.mark.skipif(**tm.no_sklearn())
     def test_exact(self):
         assert_constraint(1, 'exact')
         assert_constraint(-1, 'exact')

+    @pytest.mark.skipif(**tm.no_sklearn())
     def test_gpu_hist(self):
         assert_constraint(1, 'gpu_hist')
         assert_constraint(-1, 'gpu_hist')

@@ -12,10 +12,10 @@ def run_threaded_predict(X, rows, predict_func):
     per_thread = 20
     with ThreadPoolExecutor(max_workers=10) as e:
         for i in range(0, rows, int(rows / per_thread)):
-            try:
+            if hasattr(X, 'iloc'):
+                predictor = X.iloc[i:i+per_thread, :]
+            else:
                 predictor = X[i:i+per_thread, ...]
-            except TypeError:
-                predictor = X.iloc[i:i+per_thread, ...]
             f = e.submit(predict_func, predictor)
             results.append(f)

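A hypothetical call site for the helper above (the booster and data here are assumptions for illustration, not taken from the diff): each submitted slice exercises prediction from a separate thread, and the callable returns a boolean so the collected futures can be checked.

import numpy as np
import xgboost

rows = 100
X = np.random.randn(rows, 10)
y = np.random.randn(rows)
booster = xgboost.train({'tree_method': 'hist'},
                        xgboost.DMatrix(X, label=y), num_boost_round=4)

# Each thread predicts on its own row slice via the thread-safe inplace path.
run_threaded_predict(
    X, rows,
    lambda data: bool(np.all(np.isfinite(booster.inplace_predict(data)))))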