Skip to content

Commit 7945f7d

Browse files
committed
Filter out malformed nvidia-smi process_name XML tag
1 parent 20f01e0 commit 7945f7d

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

cwltool/cuda.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,37 @@
1010

1111
def cuda_version_and_device_count() -> Tuple[str, int]:
1212
"""Determine the CUDA version and number of attached CUDA GPUs."""
13+
# For the number of GPUs, we can use the following query
14+
cmd_count = ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader"]
1315
try:
14-
out = subprocess.check_output(["nvidia-smi", "-q", "-x"]) # nosec
16+
out_count = subprocess.check_output(cmd_count) # nosec
17+
except Exception as e:
18+
_logger.warning("Error checking number of GPUs with nvidia-smi: %s", e)
19+
return ("", 0)
20+
count = int(out_count)
21+
22+
# Since there is no specific query for the cuda version, we have to use
23+
# `nvidia-smi -q -x`
24+
# However, apparently nvidia-smi is not safe to call concurrently.
25+
# With --parallel, sometimes the returned XML will contain
26+
# <process_name>\xff...\xff</process_name>
27+
# (or other arbitrary bytes) and xml.dom.minidom.parseString will raise
28+
# "xml.parsers.expat.ExpatError: not well-formed (invalid token)"
29+
# So we either need to fix the process_name tag or, better yet, specifically
30+
# `grep cuda_version`
31+
cmd_cuda_version = "nvidia-smi -q -x | grep cuda_version"
32+
try:
33+
out = subprocess.check_output(cmd_cuda_version, shell=True) # nosec
1534
except Exception as e:
1635
_logger.warning("Error checking CUDA version with nvidia-smi: %s", e)
1736
return ("", 0)
18-
dm = xml.dom.minidom.parseString(out) # nosec
1937

20-
ag = dm.getElementsByTagName("attached_gpus")
21-
if len(ag) < 1 or ag[0].firstChild is None:
22-
_logger.warning(
23-
"Error checking CUDA version with nvidia-smi. Missing 'attached_gpus' or it is empty.: %s",
24-
out,
25-
)
38+
try:
39+
dm = xml.dom.minidom.parseString(out) # nosec
40+
except xml.parsers.expat.ExpatError as e:
41+
_logger.warning("Error parsing XML stdout of nvidia-smi: %s", e)
42+
_logger.warning("stdout: %s", out)
2643
return ("", 0)
27-
ag_element = ag[0].firstChild
2844

2945
cv = dm.getElementsByTagName("cuda_version")
3046
if len(cv) < 1 or cv[0].firstChild is None:
@@ -35,13 +51,11 @@ def cuda_version_and_device_count() -> Tuple[str, int]:
3551
return ("", 0)
3652
cv_element = cv[0].firstChild
3753

38-
if isinstance(cv_element, xml.dom.minidom.Text) and isinstance(
39-
ag_element, xml.dom.minidom.Text
40-
):
41-
return (cv_element.data, int(ag_element.data))
54+
if isinstance(cv_element, xml.dom.minidom.Text):
55+
return (cv_element.data, count)
4256
_logger.warning(
4357
"Error checking CUDA version with nvidia-smi. "
44-
"Either 'attached_gpus' or 'cuda_version' was not a text node: %s",
58+
"'cuda_version' was not a text node: %s",
4559
out,
4660
)
4761
return ("", 0)

0 commit comments

Comments
 (0)