Skip to content

Commit 58934c8

Browse files
authored
fix: count gpu uuids if NVIDIA_VISIBLE_DEVICES env set to all (#3230)
1 parent 18cbecf commit 58934c8

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

launcher/src/main.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,23 @@ fn num_cuda_devices() -> Option<usize> {
12631263
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
12641264
Ok(devices) => devices,
12651265
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
1266-
Ok(devices) => devices,
1266+
Ok(devices) => {
1267+
if devices.trim() == "all" {
1268+
// Count the number of all GPUs via nvidia-smi
1269+
let output = Command::new("nvidia-smi")
1270+
.args(["--query-gpu=uuid", "--format=csv,noheader"])
1271+
.output()
1272+
.ok()?;
1273+
1274+
String::from_utf8_lossy(&output.stdout)
1275+
.lines()
1276+
.filter(|line| !line.trim().is_empty())
1277+
.count()
1278+
.to_string()
1279+
} else {
1280+
devices
1281+
}
1282+
}
12671283
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
12681284
},
12691285
};

0 commit comments

Comments
 (0)