Skip to content

Commit 8f89cdc

Browse files
bedpostx and probtrackx gpu and multithread support
1 parent 2f85d92 commit 8f89cdc

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

nipype/interfaces/fsl/dti.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
import os
88
import warnings
9+
from shutil import which
910

1011
from ...utils.filemanip import fname_presuffix, split_filename, copyfile
1112
from ..base import (
@@ -383,6 +384,7 @@ class BEDPOSTX5InputSpec(FSLXCommandInputSpec):
383384
)
384385
grad_dev = File(exists=True, desc="grad_dev file, if gradnonlin, -g is True")
385386
use_gpu = traits.Bool(False, desc="Use the GPU version of bedpostx")
387+
num_threads = traits.Int(nohash=True, desc="Number of threads to use")
386388

387389

388390
class BEDPOSTX5OutputSpec(TraitedSpec):
@@ -451,13 +453,25 @@ class BEDPOSTX5(FSLXCommand):
451453
def __init__(self, **inputs):
452454
super().__init__(**inputs)
453455
self.inputs.on_trait_change(self._cuda_update, "use_gpu")
456+
self.inputs.on_trait_change(self._num_threads_update, "num_threads")
454457

455458
def _cuda_update(self):
456-
if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu:
459+
if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("bedpostx_gpu") is not None:
457460
self._cmd = "bedpostx_gpu"
461+
self.inputs.num_threads = 1
458462
else:
459463
self._cmd = self._default_cmd
460464

465+
def _num_threads_update(self):
466+
if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("bedpostx_gpu") is not None:
467+
self.inputs.num_threads = 1
468+
self._num_threads = self.inputs.num_threads
469+
if not isdefined(self.inputs.num_threads):
470+
if "FSLSUB_PARALLEL" in self.inputs.environ:
471+
del self.inputs.environ["FSLSUB_PARALLEL"]
472+
else:
473+
self.inputs.environ["FSLSUB_PARALLEL"] = str(self.inputs.num_threads)
474+
461475
def _run_interface(self, runtime):
462476
subjectdir = os.path.abspath(self.inputs.out_dir)
463477
if not os.path.exists(subjectdir):
@@ -1024,6 +1038,7 @@ class ProbTrackX2InputSpec(ProbTrackXBaseInputSpec):
10241038
'"vox"'
10251039
),
10261040
)
1041+
use_gpu = traits.Bool(False, desc="Use the GPU version of probtrackx2")
10271042

10281043

10291044
class ProbTrackX2OutputSpec(ProbTrackXOutputSpec):
@@ -1059,9 +1074,20 @@ class ProbTrackX2(ProbTrackX):
10591074
"""
10601075

10611076
_cmd = "probtrackx2"
1077+
_default_cmd = _cmd
10621078
input_spec = ProbTrackX2InputSpec
10631079
output_spec = ProbTrackX2OutputSpec
10641080

1081+
def __init__(self, **inputs):
1082+
super().__init__(**inputs)
1083+
self.inputs.on_trait_change(self._cuda_update, "use_gpu")
1084+
1085+
def _cuda_update(self):
1086+
if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("probtrackx2_gpu") is not None:
1087+
self._cmd = "probtrackx2_gpu"
1088+
else:
1089+
self._cmd = self._default_cmd
1090+
10651091
def _list_outputs(self):
10661092
outputs = super()._list_outputs()
10671093

0 commit comments

Comments
 (0)