[pyspark] Enable running GPU tests on variable number of GPUs. (#8335)
This commit is contained in:
parent
4633b476e9
commit
748d516c50
@ -1,3 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
|
||||
# This script is only meant for running XGBoost tests. On the official CI we have access
# to a 4-GPU cluster; the nvidia-smi discovery below is for running tests on a local machine,
# where the driver and the GPU worker might be the same machine for ease of development.
|
||||
|
||||
if ! command -v nvidia-smi &> /dev/null
then
    # nvidia-smi is not available on this machine: default to 4 GPUs
    echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\",\"2\",\"3\"]}"
    exit
else
    # https://github.com/apache/spark/blob/master/examples/src/main/scripts/getGpusResources.sh
    # Query the indices first and check the result on its own: inside a pipeline
    # the exit status of nvidia-smi would be hidden behind that of sed, and an
    # empty result would be reported as a single GPU with the address "".
    if ! INDICES=$(nvidia-smi --query-gpu=index --format=csv,noheader) || [ -z "$INDICES" ]
    then
        echo "nvidia-smi failed to list any GPU." >&2
        exit 1
    fi
    # Join the one-index-per-line output into a single string: 0","1","2
    ADDRS=$(echo "$INDICES" | sed -e ':a' -e 'N' -e'$!ba' -e 's/\n/","/g')
    echo "{\"name\": \"gpu\", \"addresses\":[\"$ADDRS\"]}"
fi
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
@ -18,8 +20,20 @@ from pyspark.sql import SparkSession
|
||||
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
|
||||
|
||||
# Relative path of the GPU discovery script; it is resolved against the
# current working directory when the script is executed.
gpu_discovery_script_path = "tests/python-gpu/test_gpu_spark/discover_gpu.sh"
# Static defaults: four GPUs and four cores per executor.
executor_gpu_amount = executor_cores = 4
|
||||
|
||||
|
||||
def get_devices():
    """Return the GPU addresses reported by the discovery script.

    The script is executed on the machine running this code, so the result
    is only meaningful when the driver and the worker are the same machine.
    """
    proc = subprocess.run(gpu_discovery_script_path, stdout=subprocess.PIPE)
    assert proc.returncode == 0, "Failed to execute discovery script."
    # The script prints a JSON object; only its "addresses" entry is needed.
    return json.loads(proc.stdout.decode("utf-8"))["addresses"]
|
||||
|
||||
|
||||
# Size the test setup from the discovered GPUs: the executor core count and
# the number of workers both equal the number of reported devices.
executor_gpu_amount = executor_cores = num_workers = len(get_devices())
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user