[pyspark] Enable running GPU tests on variable number of GPUs. (#8335)
This commit is contained in:
parent
4633b476e9
commit
748d516c50
@ -1,3 +1,16 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
|
# This script is only made for running XGBoost tests on official CI where we have access
|
||||||
|
# to a 4-GPU cluster. The discovery command is for running tests on a local machine where
|
||||||
|
# the driver and the GPU worker might be the same machine for the ease of development.
|
||||||
|
|
||||||
|
# Print a Spark resource-discovery JSON document describing the available GPUs.
if ! command -v nvidia-smi &> /dev/null
then
    # nvidia-smi is not available: default to 4 GPUs
    echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
    exit
else
    # https://github.com/apache/spark/blob/master/examples/src/main/scripts/getGpusResources.sh
    # nvidia-smi prints one GPU index per line; sed joins the lines with `","`
    # so the result can be dropped between the opening and closing quotes below.
    ADDRS=$(nvidia-smi --query-gpu=index --format=csv,noheader | sed -e ':a' -e 'N' -e'$!ba' -e 's/\n/","/g')
    # Quote the whole JSON string so the shell neither treats `[...]` as a glob
    # pattern nor word-splits the output; the emitted text is unchanged.
    echo "{\"name\": \"gpu\", \"addresses\":[\"$ADDRS\"]}"
fi
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -18,8 +20,20 @@ from pyspark.sql import SparkSession
|
|||||||
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
|
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
|
||||||
|
|
||||||
gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh"
|
gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh"
|
||||||
executor_gpu_amount = 4
|
|
||||||
executor_cores = 4
|
|
||||||
|
def get_devices():
    """Return the GPU addresses reported by the discovery script.

    Runs the script at ``gpu_discovery_script_path`` and parses the JSON
    document it prints on stdout. This works only if the driver is the same
    machine as the worker, because the script is executed locally.

    Returns:
        The value stored under the ``"addresses"`` key of the parsed JSON.

    Raises:
        AssertionError: If the discovery script exits with a non-zero status.
    """
    completed = subprocess.run(gpu_discovery_script_path, stdout=subprocess.PIPE)
    # Raise explicitly instead of using ``assert`` so the check is not stripped
    # under ``python -O``; AssertionError is kept so the failure type is the same.
    if completed.returncode != 0:
        raise AssertionError(
            "Failed to execute discovery script (exit code %d)."
            % completed.returncode
        )
    result = json.loads(completed.stdout.decode("utf-8"))
    return result["addresses"]
|
||||||
|
|
||||||
|
|
||||||
|
# Number of GPU addresses reported by the discovery script; the executor core
# count and worker count below are set to this same value.
executor_gpu_amount = len(get_devices())
|
||||||
|
executor_cores = executor_gpu_amount
|
||||||
num_workers = executor_gpu_amount
|
num_workers = executor_gpu_amount
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user